# -*- coding: utf-8 -*-
"""
Created on Thu Oct 31 17:45:02 2024

@author: Yunhong Che
"""
import gzip
import json
from beep import structure
import os
import numpy as np
from numpy import gradient
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from scipy.interpolate import interp1d
import matplotlib.pyplot as plt
import random
import seaborn as sns
from matplotlib import rcParams
import pandas as pd
from sklearn import metrics
from matplotlib.colors import Normalize
from mpl_toolkits.axes_grid1 import make_axes_locatable
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import matplotlib.colors as mcolors
import time
import re
import shap
from pyswarm import pso
from scipy.stats import gaussian_kde
from scipy.signal import savgol_filter
from sklearn.model_selection import KFold
from scipy.stats import linregress
from sklearn.linear_model import Lasso
from matplotlib.lines import Line2D
from scipy.optimize import minimize
from scipy.optimize import differential_evolution
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from xgboost import XGBRegressor
from sklearn.model_selection import GridSearchCV
from joblib import Parallel, delayed
from sklearn.preprocessing import StandardScaler
from numpy.linalg import matrix_rank
from matplotlib.colors import LinearSegmentedColormap
from sklearn.linear_model import MultiTaskLasso
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
import pickle
from scipy.spatial import distance
from matplotlib.cm import ScalarMappable
import joblib
import warnings
from scipy.signal import find_peaks
warnings.filterwarnings("ignore")

def set_random_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global RNG, PyTorch's CPU RNG
    and the current CUDA device's RNG (in that order). It then switches cuDNN
    to deterministic kernels with autotuning disabled, so runs can be repeated.

    Parameters
    ----------
    seed : int
        Seed value passed unchanged to every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True  # pick deterministic cuDNN algorithms
    cudnn.benchmark = False     # autotuning may choose different kernels per run

# Global plotting font.
font_properties = {'family': 'Times New Roman', 'size': 12}
rcParams.update({'font.family': font_properties['family'], 'font.size': font_properties['size']})

#%% Load half cell data and fitting function construction
# Full-cell test data for every cell, keyed by battery name. The path is built
# with os.path.join so the script also runs on non-Windows systems.
with gzip.open(os.path.join('..', 'data_all_cells.json.gz'), 'rt', encoding='utf-8') as f:
    data_all_cells = json.load(f)

all_batteries = list(data_all_cells.keys())

nominal_capacity = 4.84  # nominal cell capacity [Ah]

# Half-cell OCP curves measured at C/5: SOC grid ('SOC_linspace') and voltage.
OCPn_data = pd.read_csv('anode_SiO_Gr_discharge_Cover5_smoothed_dvdq_JS.csv')
OCPp_data = pd.read_csv('cathode_NCA_discharge_Cover5_smoothed_dvdq_JS.csv')

OCPn_SOC = OCPn_data['SOC_linspace'].values
OCPn_V = OCPn_data['Voltage'].values
OCPp_SOC = OCPp_data['SOC_linspace'].values
# The cathode voltage is reversed against its SOC grid (presumably to flip the
# stoichiometry direction of the discharge data - confirm against the CSV).
OCPp_V = OCPp_data['Voltage'].values[::-1].copy()

# Cubic interpolants OCP(SOC); SOC values outside the measured grid are extrapolated.
OCP_p = interp1d(OCPp_SOC, OCPp_V, kind='cubic', fill_value='extrapolate', bounds_error=False)
OCP_n = interp1d(OCPn_SOC, OCPn_V, kind='cubic', fill_value='extrapolate', bounds_error=False)


# The same half-cell OCP curves measured at C/40.
OCPn_data_40 = pd.read_csv('anode_SiO_Gr_discharge_Cover40_smooth_JS.csv')
OCPp_data_40 = pd.read_csv('cathode_NCA_discharge_Cover40_smooth_JS.csv')

OCPn_SOC_40 = OCPn_data_40['SOC_linspace'].values
OCPn_V_40 = OCPn_data_40['Voltage'].values
OCPp_SOC_40 = OCPp_data_40['SOC_linspace'].values
OCPp_V_40 = OCPp_data_40['Voltage'].values[::-1].copy()

# Cubic, extrapolating interpolants for the C/40 OCPs.
OCP_p_40 = interp1d(OCPp_SOC_40, OCPp_V_40, kind='cubic', fill_value='extrapolate', bounds_error=False)
OCP_n_40 = interp1d(OCPn_SOC_40, OCPn_V_40, kind='cubic', fill_value='extrapolate', bounds_error=False)


# Train/test membership comes from train_cells.txt / test_cells.txt (one cell
# number per line). The earlier random 70/30 split is kept for reference:
# train_batteries, test_batteries = train_test_split(all_batteries, test_size=0.3, random_state=123)

# Cell number of each battery key: the digits between the first pair of
# underscores, with leading zeros stripped (e.g. "x_007_y" -> "7").
all_battery_num = [re.search(r'_(\d+)_', entry).group(1) for entry in all_batteries]
all_battery_num_stripped = [num.lstrip('0') for num in all_battery_num]
with open('train_cells.txt', 'r') as file:
    train_cells_content = file.read()

# strip() so stray whitespace in the list file cannot silently drop a cell;
# numbers that do not occur in the data are skipped.
train_cells = [line.strip() for line in train_cells_content.splitlines()]
train_batteries_indices = [all_battery_num_stripped.index(cell) for cell in train_cells if cell in all_battery_num_stripped]
train_batteries = [all_batteries[i] for i in train_batteries_indices]

with open('test_cells.txt', 'r') as file:
    test_cells_content = file.read()

test_cells = [line.strip() for line in test_cells_content.splitlines()]
test_batteries_indices = [all_battery_num_stripped.index(cell) for cell in test_cells if cell in all_battery_num_stripped]
test_batteries = [all_batteries[i] for i in test_batteries_indices]

#%% train prediction C/5 ----- C/40 model
# Learn a MultiTaskLasso map from the electrode parameters fitted on C/5 curves
# (Cq, Cp, Cn, Cli) to those fitted on C/40 curves, then use it to reconstruct
# C/40 OCV curves from the C/5 fits.

filename = f"saved_fittings/resval_extract_data_DE_DOF3_eucl.npz"
print(filename)
norminal_c = 4.84  # nominal capacity [Ah]; fitted values are stored as fractions of it
data = np.load(filename, allow_pickle=True)
# NOTE(review): these five arrays are not read in this cell before being
# reassigned (unscaled) after the model is saved below.
all_Cq = data['all_Cq']*norminal_c
all_Cp_opt = data['all_Cp_opt']*norminal_c
all_Cn_opt = data['all_Cn_opt']*norminal_c
all_x0_opt = data['all_x0_opt']
all_y0_opt = data['all_y0_opt']

data_dict = {key: data[key] for key in data.files}
# print("Keys in the file:", data.files)
# all_cells rows are (cell name, rate label).
all_cells = data['all_cells']  
cell_names = all_cells[:, 0]  
rate_labels = all_cells[:, 1]  
unique_labels = np.unique(rate_labels)

# Split every per-row array by rate label. Arrays whose first dimension does
# not match the number of rows are copied unsplit into each label;
# 'time_consum' is skipped.
split_data_dict = {label: {} for label in unique_labels}

for key in data_dict:
    if key!='time_consum':
        if data_dict[key].shape[0] == len(rate_labels):  
            for label in unique_labels:
                split_data_dict[label][key] = data_dict[key][rate_labels == label]
        else:
            for label in unique_labels:
                split_data_dict[label][key] = data_dict[key]

# for label in unique_labels:
#     print(f"Data for rate {label}: keys -> {split_data_dict[label].keys()}")

data_sets = []
# np.unique sorts the labels, giving ['C/40', 'C/5'] -> label1 is C/40, label2 is C/5
assert len(unique_labels) == 2, "C/5 and C/40, need to change to suit more C-rates"

label1, label2 = unique_labels  
subset1, subset2 = split_data_dict[label1], split_data_dict[label2]

# Each entry: (C/40 values [Ah], C/5 values [Ah], C/40 name, C/5 name).
# Cli (lithium inventory) = Cn*x0 + Cp*y0.
data_sets = [
    (subset1['all_Cq']*norminal_c, subset2['all_Cq']*norminal_c, f'C {label1}', f'C {label2}'),
    (subset1['all_Cp_opt']*norminal_c, subset2['all_Cp_opt']*norminal_c, f'Cp {label1}', f'Cp {label2}'),
    (subset1['all_Cn_opt']*norminal_c, subset2['all_Cn_opt']*norminal_c, f'Cn {label1}', f'Cn {label2}'),
    (subset1['all_Cn_opt'] * subset1['all_x0_opt']*norminal_c + subset1['all_Cp_opt'] * subset1['all_y0_opt']*norminal_c, 
     subset2['all_Cn_opt'] * subset2['all_x0_opt']*norminal_c + subset2['all_Cp_opt'] * subset2['all_y0_opt']*norminal_c, 
     f'Cli {label1}', f'Cli {label2}')
]


np.random.seed(123)
alphas = np.logspace(-2, 2, 50)
kf = KFold(n_splits=5, shuffle=True, random_state=42)
# Features: C/5 parameters; targets: C/40 parameters of the same rows.
X = np.column_stack([
    data_sets[0][1],  # Cq @ C/5
    data_sets[1][1],  # Cp @ C/5
    data_sets[2][1],  # Cn @ C/5
    data_sets[3][1],  # Cli @ C/5
])
Y = np.column_stack([
    data_sets[0][0],  # Cq @ C/40
    data_sets[1][0],  # Cp @ C/40
    data_sets[2][0],  # Cn @ C/40
    data_sets[3][0],  # Cli @ C/40
])

colors = ['#DB6C6E', '#7BABD2', '#B3C786', '#B283B9','#F4BA61','#D5CA80','#9593C3']

# output_names = ["Cq", "Cp", "Cn", "Qli"]
output_names =[r'${\mathrm{C_q}}$',r'${\mathrm{C_p}}$',r'${\mathrm{C_n}}$',r'${\mathrm{Q_{li}}}$']
alphas = np.logspace(-2, 2, 50)

best_alphas = []

# 5-fold CV: within each fold standardise X and Y on the training part, sweep
# alpha, and keep the alpha with the lowest validation RMSE in Ah.
for train_index, val_index in kf.split(X):
    X_train, X_val = X[train_index], X[val_index]
    Y_train, Y_val = Y[train_index], Y[val_index]
    
    scaler_X = StandardScaler().fit(X_train)
    scaler_Y = StandardScaler().fit(Y_train)
    X_train_scaled = scaler_X.transform(X_train)
    X_val_scaled = scaler_X.transform(X_val)
    Y_train_scaled = scaler_Y.transform(Y_train)
    Y_val_scaled = scaler_Y.transform(Y_val)

    best_alpha, best_score = None, float('inf')
    for alpha in alphas:
        model = MultiTaskLasso(alpha=alpha)
        model.fit(X_train_scaled, Y_train_scaled)
        Y_val_pred = scaler_Y.inverse_transform(model.predict(X_val_scaled))
        Y_val_true = scaler_Y.inverse_transform(Y_val_scaled)  # equals Y_val (round trip)
        rmse = np.sqrt(((Y_val_pred - Y_val_true) ** 2).mean())
        if rmse < best_score:
            best_alpha = alpha
            best_score = rmse
    
    best_alphas.append(best_alpha)
    print(f"Fold best alpha: {best_alpha:.4f}, RMSE: {best_score:.4f}")


# Arithmetic mean of the per-fold alphas.
# NOTE(review): the grid is log-spaced; a geometric mean may be more natural.
final_alpha = np.mean(best_alphas)
print(f"\nFinal averaged alpha from 5 folds: {final_alpha:.4f}")


# Refit scalers and model on all rows.
scaler_X_full = StandardScaler().fit(X)
scaler_Y_full = StandardScaler().fit(Y)
X_scaled = scaler_X_full.transform(X)
Y_scaled = scaler_Y_full.transform(Y)

final_model = MultiTaskLasso(alpha=final_alpha)
final_model.fit(X_scaled, Y_scaled)


print("Final model coefficients:")
print(final_model.coef_)
# NOTE(review): only the model is saved; scaler_X_full / scaler_Y_full are not,
# so the .pkl alone cannot be used for prediction in a fresh session.
joblib.dump(final_model, 'saved_fittings/'+'electrode_C5_to_C40.pkl')


# Reload the raw (normalised, unscaled) fits for the reconstruction step.
data = np.load(filename, allow_pickle=True)
all_Cp_opt = data["all_Cp_opt"]
all_Cn_opt = data["all_Cn_opt"]
all_x0_opt = data["all_x0_opt"]
all_y0_opt = data["all_y0_opt"]
all_Cq = data["all_Cq"]
all_OCV_fit = data["all_OCV_fit"]
all_cell_cap = data["all_cell_cap"]
all_cell_ocv = data["all_cell_ocv"]
all_cell_vmea = data["all_cell_vmea"]
all_cells = data["all_cells"]

# Lithium inventory per row: Cp*y0 + x0*Cn.
all_Cli = np.array([
    np.array(cp) * np.array(y0) + np.array(x0) * np.array(cn)
    for cp, y0, x0, cn in zip(all_Cp_opt, all_y0_opt, all_x0_opt, all_Cn_opt)
])


all_predictions = []  

# NOTE(review): hard-coded split - assumes rows 94.. are the C/5 fits and rows
# 0..93 the C/40 fits, with the cells in the same order; TODO derive from rate_labels.
for i in range(94,len(all_Cp_opt)):
    
    X_new = np.column_stack([
        np.array(all_Cq[i]) * norminal_c,
        np.array(all_Cp_opt[i]) * norminal_c,
        np.array(all_Cn_opt[i]) * norminal_c,
        np.array(all_Cli[i]) * norminal_c,
    ])
    X_new_scaled = scaler_X_full.transform(X_new)
    Y_new_scaled_pred = final_model.predict(X_new_scaled)
    Y_new_pred = scaler_Y_full.inverse_transform(Y_new_scaled_pred)
    all_predictions.append(Y_new_pred)  # shape: (n_RPT, 4)

# Reconstruct a C/40 OCV curve for each of the first 94 rows from the predicted
# C/40 parameters (prediction k comes from C/5 row 94+k; only its first row is used).
all_cell_ocv_construct = 0*all_cell_ocv[0:94]
for idx in range(len(all_cell_ocv_construct)):
    # Back to normalised units (fractions of the nominal capacity).
    Cq = all_predictions[idx][0][0] /nominal_capacity
    Cp = all_predictions[idx][0][1] /nominal_capacity
    Cn = all_predictions[idx][0][2] /nominal_capacity
    Cli = all_predictions[idx][0][3] /nominal_capacity
    y0 = 0
    x0 = Cli / Cn
    measured_Q = np.linspace(0, Cq, 1000)
    
    SOC_p = y0 + measured_Q / Cp
    SOC_n = x0 - measured_Q / Cn
    
    Up = OCP_p_40(SOC_p)
    Un = OCP_n_40(SOC_n)
    Voc_fit = Up - Un 
    # channel 0: voltage, channel 1: capacity
    all_cell_ocv_construct[idx,:,0]=Voc_fit
    all_cell_ocv_construct[idx,:,1]=measured_Q

# Keep references for the residual-dataset cell further down.
all_cell_ocv_construct_resval = all_cell_ocv_construct
all_predictions_resval = all_predictions
#%% only needed for the first run to extract data

def extract_and_interpolate(data, num_points=1000):
    """Extract and resample the C/5 ('02C') RPT discharge curves of one cell.

    For every RPT entry in ``data['rpt']`` that has a ``'02C'`` record, the
    discharge segment is normalised and linearly resampled onto ``num_points``
    evenly spaced points. The segment runs from the last sample whose
    discharge capacity is <= 0 to the end. Capacity is divided by the nominal
    capacity (4.84) and voltage by 4.2.

    A curve is skipped when:
    - its discharge segment has 50 or fewer samples, which is too few to read
      the current 50 samples in;
    - the resampled capacity jumps by more than 0.05 between neighbouring points;
    - its maximum normalised capacity is below 0.7.

    Parameters
    ----------
    data : dict
        One cell's record. ``data['rpt']`` is an iterable of per-RPT dicts; a
        ``'02C'`` record must provide ``'discharge_capacity'``, ``'voltage'``,
        ``'EFC'`` and ``'current'`` sequences.
    num_points : int, optional
        Number of resampled points per curve (default 1000).

    Returns
    -------
    inputs : np.ndarray
        Shape ``(n_curves, num_points, 3)``. Channels are normalised voltage,
        normalised capacity, and the current sampled 50 points into the
        discharge divided by 4.84 (constant along the curve).
    outputs : np.ndarray
        Shape ``(n_curves, num_points, 2)``: normalised voltage and capacity.
    efcs : np.ndarray
        Shape ``(n_curves,)``: mean EFC of each kept RPT.

    When no curve is kept, all three arrays are empty with shape ``(0,)``.
    """
    nominal_capacity = 4.84
    max_voltage = 4.2
    # The current is read this many samples after the start of discharge.
    current_offset = 50
    x_new = np.linspace(0, 1, num_points)  # common resampling grid

    inputs, outputs, efcs = [], [], []
    for rpt_data in data['rpt']:
        if '02C' not in rpt_data:
            continue
        rpt = rpt_data['02C']
        discharge_capacity = np.array(rpt['discharge_capacity'])
        discharge_voltage = np.array(rpt['voltage'])
        EFCs = np.array(rpt['EFC'])
        I = np.array(rpt['current'])

        # Discharge starts at the last sample whose capacity is still <= 0. If
        # there is none, the slices below are empty and the curve is skipped.
        non_positive = np.where(discharge_capacity <= 0)[0]
        dis_start_idx = non_positive[-1] if len(non_positive) > 0 else len(discharge_capacity)

        discharge_capacity = discharge_capacity[dis_start_idx:] / nominal_capacity
        discharge_voltage = discharge_voltage[dis_start_idx:] / max_voltage

        # Strictly more than `current_offset` samples are needed, so that
        # I[dis_start_idx + current_offset] exists.
        if (len(discharge_capacity) <= current_offset
                or len(I) <= dis_start_idx + current_offset):
            continue

        interp_ocv_v = interp1d(np.linspace(0, 1, len(discharge_voltage)), discharge_voltage, kind='linear')
        interp_ocv_q = interp1d(np.linspace(0, 1, len(discharge_capacity)), discharge_capacity, kind='linear')
        discharge_voltage_interp = interp_ocv_v(x_new)
        discharge_capacity_interp = interp_ocv_q(x_new)

        # Reject curves with large capacity gaps or that never reach 70 % of nominal.
        if max(np.diff(discharge_capacity_interp)) > 0.05 or max(discharge_capacity) < 0.7:
            continue

        current = I[dis_start_idx + current_offset] / nominal_capacity
        input_data = np.stack([discharge_voltage_interp, discharge_capacity_interp,
                               np.full(num_points, current)], axis=-1)
        output_data = np.stack([discharge_voltage_interp, discharge_capacity_interp], axis=-1)

        inputs.append(input_data)
        outputs.append(output_data)
        efcs.append(np.mean(EFCs))

    return np.array(inputs), np.array(outputs), np.array(efcs)


# Extract the C/5 RPT discharge curves of every cell. Each cell is tagged
# "C/5_Cycle", which makes the fitting code below use the C/5 half-cell OCPs.
all_inputs, all_outputs = [], []   
all_cells, all_efcs = [], []
for i in range(0,len(all_batteries)): # alternatives: train_batteries / all_batteries
    battery = all_batteries[i]
    print("read cell", battery)
    cap = []  # NOTE(review): never used
    inputs_cell, outputs_cell, efcs_cell = extract_and_interpolate(data_all_cells[battery])
    all_inputs.append(inputs_cell)
    all_outputs.append(outputs_cell)
    all_cells.append((battery,"C/5_Cycle"))
    all_efcs.append(efcs_cell)

#%%
def objective_function(params, c_rate, measure_Q, measure_V):
    """Loss for fitting the 3-parameter electrode model to a measured OCV curve.

    The positive electrode starts at stoichiometry 0 and the negative one at
    ``NP_offset``. Both are shifted by ``measure_Q`` divided by their
    capacity. The model voltage is OCP_p - OCP_n, using the half-cell curves
    for ``c_rate``.

    The data term is the mean, over the measured (Q, V) points, of the
    Euclidean distance to the nearest point of the model curve (sampled on the
    same Q grid). An L2 penalty of 0.01 * (Cp^2 + Cn^2 + NP_offset^2) is added.

    Parameters
    ----------
    params : sequence of 3 floats
        (Cp, Cn, NP_offset).
    c_rate : str
        'C/40_Cycle' or 'C/5_Cycle'.
    measure_Q, measure_V : np.ndarray
        Measured capacity and voltage samples.

    Returns
    -------
    float
        Data term plus regularisation.

    Raises
    ------
    ValueError
        If ``c_rate`` is not one of the supported labels.
    """
    Cp, Cn, NP_offset = params
    pos_start = 0
    soc_pos = pos_start + measure_Q / Cp
    soc_neg = NP_offset - measure_Q / Cn

    if c_rate == 'C/40_Cycle':
        ocp_pos, ocp_neg = OCP_p_40, OCP_n_40
    elif c_rate == 'C/5_Cycle':
        ocp_pos, ocp_neg = OCP_p, OCP_n
    else:
        raise ValueError("Unsupported c_rate type.")

    model_voltage = ocp_pos(soc_pos) - ocp_neg(soc_neg)
    penalty = 0.01 * (Cp**2 + Cn**2 + NP_offset**2)

    measured_pts = np.column_stack((measure_Q, measure_V))
    model_pts = np.column_stack((measure_Q, model_voltage))
    # For each measured point, the distance to its closest model point.
    nearest = distance.cdist(measured_pts, model_pts, "euclidean").min(axis=1)
    return nearest.mean() + penalty


def get_bounds_by_dof(dof):
    """Return the optimiser search box for a model with ``dof`` free parameters.

    Parameters
    ----------
    dof : int
        2 (NP_ratio, NP_offset), 3 (Cp, Cn, NP_offset) or 4 (Cp, Cn, x0, y0).

    Returns
    -------
    tuple of (list, list, list)
        Lower bounds, upper bounds and parameter names, in matching order.
        Fresh lists are built on every call.

    Raises
    ------
    ValueError
        If ``dof`` is not 2, 3 or 4.
    """
    presets = (
        (4, [0.2, 0.2, 0, 0], [1.2, 1.2, 1.0, 1.0], ['Cp', 'Cn', 'x0', 'y0']),
        (3, [0.2, 0.2, 0], [1.2, 1.2, 1.0], ['Cp', 'Cn', 'NP_offset']),
        (2, [0.2, 0], [1.2, 1.0], ['NP_ratio', 'NP_offset']),
    )
    for n_free, lower, upper, labels in presets:
        if dof == n_free:
            return lower, upper, labels
    raise ValueError("DOF must be 2, 3 or 4")


def run_de(trial, Q, V, rate):
    """Run one differential-evolution fit of the 3-DOF electrode model.

    ``trial`` seeds both NumPy's global RNG and the optimiser, so each trial
    is reproducible. The search box comes from ``get_bounds_by_dof(3)``.

    Parameters
    ----------
    trial : int
        Seed for this trial.
    Q, V : np.ndarray
        Measured capacity and voltage passed to ``objective_function``.
    rate : str
        'C/5_Cycle' or 'C/40_Cycle'; selects the half-cell OCP set.

    Returns
    -------
    tuple
        (best parameter vector [Cp, Cn, NP_offset], its loss).
    """
    lower, upper, _ = get_bounds_by_dof(3)
    search_box = list(zip(lower, upper))
    np.random.seed(trial)
    outcome = differential_evolution(
        objective_function,
        search_box,
        args=(rate, Q, V),
        popsize=20,
        maxiter=200,
        seed=trial,
    )
    return outcome.x, outcome.fun


def optimize_cycle(cycle_index, cell_inputs, cell_outputs, cell_efcs, cell_rate, cell_name):
    """Fit the 3-DOF electrode model to a single RPT curve of one cell.

    Five differential-evolution trials (seeds 0-4) run in parallel through
    ``run_de``. The trial with the lowest loss is kept; on a tie the earliest
    one wins.

    Parameters
    ----------
    cycle_index : int
        Row of ``cell_inputs`` / ``cell_outputs`` to fit.
    cell_inputs, cell_outputs : np.ndarray
        Arrays indexed ``[cycle, point, channel]``. Capacity is read from
        channel 1 of ``cell_inputs``. Normalised voltage is read from channel 0
        of ``cell_outputs`` and multiplied by 4.2 here.
    cell_efcs, cell_name
        Accepted for interface compatibility; not used.
    cell_rate : str
        'C/5_Cycle' or 'C/40_Cycle'; selects the half-cell OCP set.

    Returns
    -------
    dict
        Keys:
        - 'Cp_opt', 'Cn_opt': fitted electrode capacities.
        - 'x0_opt': the fitted NP offset.
        - 'y0_opt': always 0.
        - 'fitted_Voc': model OCV on the measured capacity grid.
        - 'cell_cap', 'Cq': both the last capacity sample of the curve.

    Raises
    ------
    ValueError
        If ``cell_rate`` is not one of the two supported labels.
    """
    measure_Q = cell_inputs[cycle_index, :, 1]
    measure_V = cell_outputs[cycle_index, :, 0] * 4.2
    num_trials = 5
    trial_results = Parallel(n_jobs=5)(
        delayed(run_de)(seed, measure_Q, measure_V, cell_rate) for seed in range(num_trials)
    )

    # Keep the lowest-loss trial (strict '<' keeps the earliest on ties).
    winner, winner_loss = None, float('inf')
    for params, loss in trial_results:
        if loss < winner_loss:
            winner, winner_loss = params, loss

    Cp_opt, Cn_opt, NP_offset = [float(value) for value in winner]
    x0_opt, y0_opt = NP_offset, 0

    soc_pos = y0_opt + measure_Q / Cp_opt
    soc_neg = x0_opt - measure_Q / Cn_opt

    if cell_rate == 'C/5_Cycle':
        ocp_pos, ocp_neg = OCP_p, OCP_n
    elif cell_rate == 'C/40_Cycle':
        ocp_pos, ocp_neg = OCP_p_40, OCP_n_40
    else:
        raise ValueError("Unsupported cell_rate")

    return {
        'Cp_opt': Cp_opt,
        'Cn_opt': Cn_opt,
        'x0_opt': x0_opt,
        'y0_opt': y0_opt,
        'fitted_Voc': ocp_pos(soc_pos) - ocp_neg(soc_neg),
        'cell_cap': cell_inputs[cycle_index, -1, 1],
        'Cq': measure_Q[-1],
    }


def optimize_cell(cell_index):
    """Fit every RPT curve of one cell in parallel.

    Reads the module-level lists ``all_inputs``, ``all_outputs``,
    ``all_cells`` (entries of the form (name, rate label)) and ``all_efcs``,
    and runs ``optimize_cycle`` on each curve of cell ``cell_index``.

    Returns
    -------
    tuple of 7 lists
        One entry per RPT, in this order: Cp, Cn, x0, y0, fitted OCV,
        cell capacity, Cq.
    """
    print(f"Processing cell {cell_index + 1}/{len(all_outputs)}: {all_cells[cell_index]}")
    record = all_cells[cell_index]
    cell_name, cell_rate = record[0], record[1]
    cell_inputs = all_inputs[cell_index]
    cell_outputs = all_outputs[cell_index]
    cell_efcs = all_efcs[cell_index]

    per_cycle = Parallel(n_jobs=56)(
        delayed(optimize_cycle)(k, cell_inputs, cell_outputs, cell_efcs, cell_rate, cell_name)
        for k in range(len(cell_inputs))
    )

    fields = ('Cp_opt', 'Cn_opt', 'x0_opt', 'y0_opt', 'fitted_Voc', 'cell_cap', 'Cq')
    return tuple([entry[field] for entry in per_cycle] for field in fields)


# Fit every cell in parallel and gather the per-RPT electrode parameters.
all_Cp_opt, all_Cn_opt, all_x0_opt, all_y0_opt = [], [], [], []
all_OCV_fit = []
all_cell_cap = []
all_Cq = []

results = Parallel(n_jobs=56)(delayed(optimize_cell)(c_idx) for c_idx in range(len(all_outputs)))
for result in results:
    cell_Cp_opt, cell_Cn_opt, cell_x0_opt, cell_y0_opt, cell_OCV_fit, cell_cell_cap, cell_Cq = result
    all_Cp_opt.append(cell_Cp_opt)
    all_Cn_opt.append(cell_Cn_opt)
    all_x0_opt.append(cell_x0_opt)
    all_y0_opt.append(cell_y0_opt)
    all_OCV_fit.append(cell_OCV_fit)
    all_cell_cap.append(cell_cell_cap)
    all_Cq.append(cell_Cq)


# Alternative .npz export (disabled):
# filename = "dynamic_extract_data_pso1.npz"
# np.savez(
#     filename,
#     all_Cp_opt=all_Cp_opt,
#     all_Cn_opt=all_Cn_opt,
#     all_x0_opt=all_x0_opt,
#     all_y0_opt=all_y0_opt,
#     all_Cq=all_Cq,
#     all_OCV_fit=all_OCV_fit,
#     all_cell_cap=all_cell_cap,
#     all_cell_ocv=all_outputs,
#     all_cell_vmea=all_inputs,
#     all_cells=all_cells,
#     all_efcs=all_efcs
# )

# print(f"Data saved as {filename} in the current directory.")


data = {
    "all_Cp_opt": all_Cp_opt,
    "all_Cn_opt": all_Cn_opt,
    "all_x0_opt": all_x0_opt,
    "all_y0_opt": all_y0_opt,
    "all_Cq": all_Cq,
    "all_OCV_fit": all_OCV_fit,
    "all_cell_cap": all_cell_cap,
    "all_cell_ocv": all_outputs,
    "all_cell_vmea": all_inputs,
    "all_cells": all_cells,
    "all_efcs": all_efcs
}

# One name for both the write and the log message, so they cannot disagree
# (the message previously reported a "...pso1.pkl" file that was never written).
extract_filename = "tesla_extract_data_pso.pkl"
with open(extract_filename, "wb") as f:
    pickle.dump(data, f)

print(f"data saved as: {extract_filename}")

#%% Plot the C/5-fitted electrode parameters of every cell against EFC
norminal_c = 4.84
colors = ['#DB6C6E', '#7BABD2', '#B3C786', '#B283B9','#F4BA61','#D5CA80','#9593C3']

# Reload the C/5 electrode fits written by the extraction cell.
# NOTE: pickle.load can execute code from the file - only load trusted files.
with open("tesla_extract_data_pso.pkl", "rb") as f:
    data = pickle.load(f)
    
plt.rcParams['font.family'] = 'Times New Roman'
rcParams['mathtext.fontset'] = 'custom'
rcParams['mathtext.rm'] = 'Times New Roman'
rcParams['mathtext.it'] = 'Times New Roman:italic'
rcParams['mathtext.bf'] = 'Times New Roman:bold'
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
plt.rcParams['axes.edgecolor'] = 'black'
plt.rcParams['axes.linewidth'] = 0.8
plt.rcParams['font.size'] = 12

all_Cp_opt = data["all_Cp_opt"]
all_Cn_opt = data["all_Cn_opt"]
all_x0_opt = data["all_x0_opt"]
all_y0_opt = data["all_y0_opt"]
all_Cq = data["all_Cq"]
all_OCV_fit = data["all_OCV_fit"]
all_cell_cap = data["all_cell_cap"]
all_cell_ocv = data["all_cell_ocv"]
all_cell_vmea = data["all_cell_vmea"]
all_cells = data["all_cells"]
all_efcs = data["all_efcs"]
# Lithium inventory per RPT: Cp*y0 + x0*Cn (normalised units).
all_Cli = [
    np.array(cp) * np.array(y0) + np.array(x0) * np.array(cn)
    for cp, y0, x0, cn in zip(all_Cp_opt, all_y0_opt, all_x0_opt, all_Cn_opt)
]

fig, axs = plt.subplots(1, 4, figsize=(26 / 2.54, 5.5 / 2.54), dpi=600, gridspec_kw={'wspace': 0.45})
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
# Real booleans instead of the deprecated 'on' strings. Both calls act only on
# the current (last) axes, and the second call turns the top/right ticks off again.
plt.tick_params(top=True, right=True, which='both')
plt.tick_params(axis='both', which='both', bottom=True, top=False, left=True, right=False)
# color_map = plt.colormaps.get_cmap("coolwarm")
# One line per cell; values are multiplied by norminal_c to give Ah.
for i in range(len(all_cell_cap)):
    axs[3].plot(all_efcs[i],np.array(all_Cq[i]) * norminal_c,'-D',markersize=2, alpha=0.5, color=colors[3])
    axs[0].plot(all_efcs[i],np.array(all_Cp_opt[i]) * norminal_c,'-o',markersize=2, alpha=0.5, color=colors[0])
    axs[1].plot(all_efcs[i],np.array(all_Cn_opt[i]) * norminal_c,'-^',markersize=2, alpha=0.5, color=colors[4])
    axs[2].plot(all_efcs[i],np.array(all_Cli[i]) * norminal_c,'-s',markersize=2, alpha=0.5, color=colors[1])
    
axs[0].set_xlabel('EFC [Cycles]')
axs[0].set_ylabel(r'${\mathrm{C_p}}$ [Ah]')
axs[1].set_xlabel('EFC [Cycles]')
axs[1].set_ylabel(r'${\mathrm{C_n}}$ [Ah]')
axs[2].set_xlabel('EFC [Cycles]')
axs[2].set_ylabel(r'${\mathrm{Q_{li}}}$ [Ah]')     
axs[3].set_xlabel('EFC [Cycles]')
axs[3].set_ylabel(r'${\mathrm{C_{q}}}$ [Ah]')                   
plt.tight_layout()
plt.show()


#%% Map each cell's C/5 electrode parameters to C/40 with the saved Lasso model
all_predictions = []  
# NOTE(review): only the model is persisted; scaler_X_full / scaler_Y_full used
# below must still be in memory from the training cell, otherwise this raises NameError.
final_model=joblib.load('saved_fittings/'+'electrode_C5_to_C40.pkl')
for i in range(len(all_Cp_opt)):
   
    # Features in Ah, in the same column order as the training matrix X: Cq, Cp, Cn, Cli.
    X_new = np.column_stack([
        np.array(all_Cq[i]) * norminal_c,
        np.array(all_Cp_opt[i]) * norminal_c,
        np.array(all_Cn_opt[i]) * norminal_c,
        np.array(all_Cli[i]) * norminal_c,
    ])

    X_new_scaled = scaler_X_full.transform(X_new)
    Y_new_scaled_pred = final_model.predict(X_new_scaled)
    Y_new_pred = scaler_Y_full.inverse_transform(Y_new_scaled_pred)
    all_predictions.append(Y_new_pred)  # shape: (n_RPT, 4)

fig, axs = plt.subplots(1, 4, figsize=(26 / 2.54, 5.5 / 2.54), dpi=600, gridspec_kw={'wspace': 0.45})
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
# Both tick_params calls act only on the current (last) axes; the second one
# turns the top/right ticks off again.
plt.tick_params(top=True, right=True, which='both')
plt.tick_params(axis='both', which='both', bottom=True, top=False, left=True, right=False)

# Predicted C/40 parameters per cell; columns are Cq, Cp, Cn, Qli.
for i in range(len(all_predictions)):
    efcs = all_efcs[i]
    Y_pred = all_predictions[i]  # shape: n_RPT x 4
    axs[0].plot(efcs, Y_pred[:, 1], '-o', markersize=2, alpha=0.5, color=colors[0])  # Cp
    axs[1].plot(efcs, Y_pred[:, 2], '-^', markersize=2, alpha=0.5, color=colors[4])  # Cn
    axs[2].plot(efcs, Y_pred[:, 3], '-s', markersize=2, alpha=0.5, color=colors[1])  # Qli
    axs[3].plot(efcs, Y_pred[:, 0], '-D', markersize=2, alpha=0.5, color=colors[3])  # Cq

axs[0].set_xlabel('EFC [Cycles]')
axs[0].set_ylabel(r'${\mathrm{C_p}}$ [Ah]')
axs[1].set_xlabel('EFC [Cycles]')
axs[1].set_ylabel(r'${\mathrm{C_n}}$ [Ah]')
axs[2].set_xlabel('EFC [Cycles]')
axs[2].set_ylabel(r'${\mathrm{Q_{li}}}$ [Ah]')
axs[3].set_xlabel('EFC [Cycles]')
axs[3].set_ylabel(r'${\mathrm{C_{q}}}$ [Ah]')

# plt.tight_layout()
plt.show()
# Keep a reference to these predictions (the name is reused by later cells).
all_predictions_test = all_predictions

#%% Compare C/5-fitted and predicted C/40 electrode parameters for selected cells
fig, axs = plt.subplots(1, 4, figsize=(26 / 2.54, 5.5 / 2.54), dpi=600, gridspec_kw={'wspace': 0.45})
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
plt.tick_params(top=True, right=True, which='both')
plt.tick_params(axis='both', which='both', bottom=True, top=False, left=True, right=False)

# Hand-picked cell indices (other candidates listed in the trailing comment).
for i_idx, i in enumerate([7,216]):  #227, 38, 26, 100, 235 [7,216]
    efcs = all_efcs[i]
    print('results for cell:',all_cells[i][0])
    
    # C/5-fitted parameters, converted to Ah.
    Cq_5 = np.array(all_Cq[i]) * norminal_c
    Cp_5 = np.array(all_Cp_opt[i]) * norminal_c
    Cn_5 = np.array(all_Cn_opt[i]) * norminal_c
    Cli_5 = np.array(all_Cli[i]) * norminal_c
    
    # Predicted C/40 parameters (columns: Cq, Cp, Cn, Qli).
    Y_pred = all_predictions[i]  
    Cq_40 = Y_pred[:, 0]
    Cp_40 = Y_pred[:, 1] 
    Cn_40 = Y_pred[:, 2] 
    Cli_40 = Y_pred[:, 3]
    
    # NOTE(review): these palettes are rebuilt with the same values on every
    # iteration; colors_5 and colors_40 are identical, so line style and marker
    # are what distinguish C/5 from C/40.
    colors = ['#DB6C6E', '#7BABD2', '#B3C786', '#B283B9','#F4BA61','#D5CA80','#9593C3']
    colors_5 = ['#DB6C6E', '#F4BA61', '#7BABD2', '#B283B9'] # C/5 
    colors_40 = ['#DB6C6E', '#F4BA61', '#7BABD2', '#B283B9'] # C/40 
    
    
    # Legend labels only for the first cell, to avoid duplicate entries.
    label_5 = 'C/5' if i_idx == 0 else None
    label_40 = 'C/40' if i_idx == 0 else None

    axs[0].plot(efcs, Cp_5, '--x', label=label_5, color=colors_5[0])
    axs[0].plot(efcs, Cp_40, '-o', label=label_40,markersize=5, color=colors_40[0])

    axs[1].plot(efcs, Cn_5, '--+', label=label_5, color=colors_5[1])
    axs[1].plot(efcs, Cn_40, '-^', label=label_40, markersize=5,color=colors_40[1])

    axs[2].plot(efcs, Cli_5, '--3', label=label_5, color=colors_5[2])
    axs[2].plot(efcs, Cli_40, '-s', label=label_40, markersize=5,color=colors_40[2])

    axs[3].plot(efcs, Cq_5, '--2', label=label_5, color=colors_5[3])
    axs[3].plot(efcs, Cq_40, '-D', label=label_40, markersize=5,color=colors_40[3])
    
axs[0].set_xlabel('EFC [Cycles]')
axs[0].set_ylabel(r'${\mathrm{C_p}}$ [Ah]')
axs[1].set_xlabel('EFC [Cycles]')
axs[1].set_ylabel(r'${\mathrm{C_n}}$ [Ah]')
axs[2].set_xlabel('EFC [Cycles]')
axs[2].set_ylabel(r'${\mathrm{Q_{li}}}$ [Ah]')
axs[3].set_xlabel('EFC [Cycles]')
axs[3].set_ylabel(r'${\mathrm{C_q}}$ [Ah]')
for ax in axs:
    ax.legend(loc='lower left',
        handletextpad=0.1, 
        labelspacing=0.05,
        frameon=False,
        fontsize=10)
# plt.tight_layout()
plt.show()

#%% One cell (index i): reconstructed C/40 OCV vs measured and fitted C/5 OCV, per RPT
custom_cmap = LinearSegmentedColormap.from_list("colors", ['#e0e0e0','#606060'])
i = 7  
efcs = all_efcs[i]
Y_pred = all_predictions[i]  

# Predicted C/40 parameters per RPT (columns: Cq, Cp, Cn, Qli, in Ah).
Cli_all = Y_pred[:, 3] 
Cn_all = Y_pred[:, 2] 
x0_all = Cli_all / Cn_all  # negative-electrode start stoichiometry (model with y0 = 0)
y0_all = np.zeros_like(x0_all)
OCV_measure = all_cell_vmea[i]
OCV_fit = np.array(all_OCV_fit[i])
plt.figure(figsize=(11.5 / 2.54, 6 / 2.54), dpi=600)
num_points = len(efcs)
norm = Normalize(vmin=efcs[0], vmax=efcs[-1])  
for idx in range(len(efcs)):
    # Back to normalised units (fractions of the nominal capacity).
    Cq = Y_pred[idx, 0] /nominal_capacity
    Cp = Y_pred[idx, 1] /nominal_capacity
    Cn = Y_pred[idx, 2] /nominal_capacity
    Cli = Y_pred[idx, 3] /nominal_capacity
    x0 = x0_all[idx]
    y0 = y0_all[idx]

    # measured_Q: 1000 evenly spaced points from 0 to Cq
    measured_Q = np.linspace(0, Cq, 1000)
    SOC_p = y0 + measured_Q / Cp
    SOC_n = x0 - measured_Q / Cn
    Up = OCP_p_40(SOC_p)
    Un = OCP_n_40(SOC_n)
    Voc_fit = Up - Un 
   
    # Later RPTs are drawn progressively darker.
    color_scale = 1 / (0.05 * idx + 1)
    plt.plot(measured_Q*norminal_c, Voc_fit, alpha=0.8, color=np.array(mcolors.to_rgb(colors[1])) * color_scale,label='C/40 Reconstructed' if idx==0 else None)
    plt.plot(OCV_measure[idx,:,1]*norminal_c, OCV_measure[idx,:,0]*4.2,'--', color=np.array(mcolors.to_rgb(colors[0])) * color_scale, alpha=0.8, label='C/5 Measured' if idx==0 else None)
    plt.plot(OCV_measure[idx,:,1]*norminal_c, OCV_fit[idx,:],'-.', color=np.array(mcolors.to_rgb(colors[5])) * color_scale, alpha=0.8, label='C/5 Fitted' if idx==0 else None)
    
plt.xlabel('Q [Ah]')
plt.ylabel('OCV [V]')
# plt.grid(True)
# The colorbar is a grey EFC reference scale; the curves themselves are shaded
# by darkening a fixed colour per RPT, not drawn from custom_cmap.
sm = ScalarMappable(cmap=custom_cmap, norm=norm)
sm.set_array([])
# sm is not attached to any Axes, so name the Axes to take space from
# explicitly; recent Matplotlib refuses to guess.
cbar = plt.colorbar(sm, ax=plt.gca())
cbar.set_label('EFC [Cycles]')

tick_locs = np.linspace(efcs[0], efcs[-1], num=5)
cbar.set_ticks(tick_locs)
cbar.set_ticklabels([f"{int(t)}" for t in tick_locs])
plt.legend(loc='lower left',
    handletextpad=0.1, 
    labelspacing=0.05,
    frameon=False,
    fontsize=10)
plt.show()

#%% dV/dQ of the fitted C/5 curves vs the reconstructed C/40 curves (same cell)
custom_cmap = LinearSegmentedColormap.from_list("colors", ['#e0e0e0','#606060'])
plt.figure(figsize=(11.5 / 2.54, 6 / 2.54), dpi=600)
num_points = len(efcs)
norm = Normalize(vmin=efcs[0], vmax=efcs[-1])  
# Fitted C/5 curves: numerical derivative with respect to capacity in Ah.
for idx in range(num_points):
    color_scale = 1 / (0.05 * idx + 1)
    dv_dq_fit = np.gradient(OCV_fit[idx,:], OCV_measure[idx,:,1]*norminal_c)
    plt.plot(OCV_measure[idx,:,1]*norminal_c, -dv_dq_fit, '--', alpha=0.8,color=np.array(mcolors.to_rgb(colors[5])) * color_scale,
              label='C/5 Fitted' if idx == 0 else None)
    
# Reconstructed C/40 curves from the predicted electrode parameters.
for idx in range(num_points):
    Cq = Y_pred[idx, 0] / nominal_capacity
    Cp = Y_pred[idx, 1] / nominal_capacity
    Cn = Y_pred[idx, 2] / nominal_capacity
    Cli = Y_pred[idx, 3] / nominal_capacity
    x0 = x0_all[idx]
    y0 = y0_all[idx]

    measured_Q = np.linspace(0, Cq, 1000)
    SOC_p = y0 + measured_Q / Cp
    SOC_n = x0 - measured_Q / Cn
    Up = OCP_p_40(SOC_p)
    Un = OCP_n_40(SOC_n)
    Voc_recon = Up - Un
    dv_dq_recon = np.gradient(Voc_recon, measured_Q* norminal_c)
    # Later RPTs are drawn progressively darker.
    color_scale = 1 / (0.05 * idx + 1)

    plt.plot(measured_Q * norminal_c, -dv_dq_recon, alpha=0.8, color=np.array(mcolors.to_rgb(colors[1])) * color_scale,
             label='C/40 Reconstructed' if idx == 0 else None)
    
plt.xlabel('Q [Ah]')
plt.ylabel('dV/dQ [V/Ah]')
plt.ylim([0, 1])

# Grey EFC reference scale (the curves are not coloured from custom_cmap).
sm = ScalarMappable(cmap=custom_cmap, norm=norm)
sm.set_array([])
# sm is not attached to any Axes; pass the target Axes explicitly.
cbar = plt.colorbar(sm, ax=plt.gca())

cbar.set_label('EFC [Cycles]')
tick_locs = np.linspace(efcs[0], efcs[-1], num=5)
cbar.set_ticks(tick_locs)
cbar.set_ticklabels([f"{int(t)}" for t in tick_locs])
plt.legend(loc='best',
    handletextpad=0.1, 
    labelspacing=0.05,
    bbox_to_anchor=(0.35, 0.75),
    frameon=False,
    fontsize=10)
plt.show()


#%%
filename = f"saved_fittings/resval_extract_data_DE_DOF3_eucl.npz"
data = np.load(filename, allow_pickle=True)
all_Cp_opt = data["all_Cp_opt"]
all_Cn_opt = data["all_Cn_opt"]
all_x0_opt = data["all_x0_opt"]
all_y0_opt = data["all_y0_opt"]
all_cell_ocv = data["all_cell_ocv"]
all_cell_vmea = data['all_cell_vmea']
all_cells = data['all_cells']
all_OCV_fit = data["all_OCV_fit"]
all_Cq = data["all_Cq"]
all_fit_results = data['all_fit_results']
all_cell_Vreal = all_cell_ocv[:,:,0]*4.2
all_cell_Qreal = all_cell_ocv[:,:,1]
all_cell_Vm = all_cell_vmea[:,:,0]*4.2
all_cell_Qm = all_cell_vmea[:,:,1]
all_Cli = np.array([
    np.array(cp) * np.array(y0) + np.array(x0) * np.array(cn)
    for cp, y0, x0, cn in zip(all_Cp_opt, all_y0_opt, all_x0_opt, all_Cn_opt)
])


real_OCV = all_cell_Vm[94:,:]
fit_OCV = np.array(all_OCV_fit[94:,:])
Measure_Q = all_cell_Qm[94:,:]
Measure_V = all_cell_Vm[94:,:]
residual = real_OCV - fit_OCV  # shape: (94, 1000)
valid_indices = [i for i in range(94) if all_Cq[i] >= 0.7]

X_features = []
Y_targets = []
for i in valid_indices:
    # 
    voc_fit = fit_OCV[i,:]
    voc_real = real_OCV[i,:]
    q_meas = Measure_Q[i,:]
    v_meas = Measure_V[i,:]
    
    cp = all_Cp_opt[i]
    cn = all_Cn_opt[i]
    cq = all_Cq[i]
    cli = all_Cli[i]  # 
    # 
    X_i = np.column_stack([
        voc_fit,
        # q_meas,
        # v_meas,
        np.full_like(voc_fit, cp),
        np.full_like(voc_fit, cn),
        np.full_like(voc_fit, cli),
    ])
    X_features.append(X_i.flatten())  # shape: (6000,)
    Y_targets.append((voc_real - voc_fit))  # shape: (1000,)
    

X_all = np.stack(X_features)  # shape: (n_samples, 6000)
Y_all = np.stack(Y_targets)   # shape: (n_samples, 1000)
X_train, X_val, Y_train, Y_val = train_test_split(X_all, Y_all, test_size=0.2, random_state=42)
X_test, Y_test = X_val, Y_val

scaler_X = StandardScaler().fit(X_train)
scaler_Y = StandardScaler().fit(Y_train)

# --- Standardisers for the C/5 -> C/40 compensation models ------------------
# all_predictions_resval / all_cell_ocv_construct_resval are defined elsewhere
# in the file (not visible in this chunk).
all_predictions = all_predictions_resval
all_cell_ocv_construct = all_cell_ocv_construct_resval
# model_electrode = joblib.load( 'saved_fittings/'+'electrode_C5_to_C40.pkl')
# NOTE(review): channel 0 is used as-is here, while all_cell_Vreal was
# multiplied by 4.2; presumably the reconstruction is already in volts -- confirm.
all_cell_Vconstruct = all_cell_ocv_construct[:,:,0]
real_OCV = all_cell_Vreal[:94]
# Inputs use the OCV reconstructed from predicted electrode states
# (all_cell_Vconstruct) instead of the direct fit all_OCV_fit[:94].
Measure_Q = all_cell_Qm[94:,:] # the measurement should be C/5 based
Measure_V = all_cell_Vm[94:,:]
real_Q = all_cell_Qreal[:94]
valid_indices = [i for i in range(94) if all_Cq[i] >= 0.7 ]

X_features = []
Y_targets = []
X_features_q = []
Y_targets_q = []
for i in valid_indices:
    # Reconstructed voltage (input) and reference voltage/capacity (targets).
    # voc_fit = fit_OCV[i,:]
    voc_fit = all_cell_Vconstruct[i,:]
    voc_real = real_OCV[i,:]
    q_meas = Measure_Q[i,:]
    v_meas = Measure_V[i,:]
    q_real = real_Q[i,:]
    
    # First row of the prediction for this cell, columns [Cq, Cp, Cn, Cli],
    # divided by nominal_capacity (defined elsewhere in the file).
    Cq = all_predictions[i][0][0] /nominal_capacity
    Cp = all_predictions[i][0][1] /nominal_capacity
    Cn = all_predictions[i][0][2] /nominal_capacity
    Cli = all_predictions[i][0][3] /nominal_capacity
    # y0 / x0 are computed but not used by the features below.
    y0 = 0
    x0 = Cli / Cn
    # Predicted capacity axis: 1000 evenly spaced points from 0 to Cq.
    predict_q = np.linspace(0, Cq, 1000)
    
    # Voltage-model features: reconstructed voltage + broadcast parameters.
    X_i = np.column_stack([
        voc_fit,
        # q_meas,
        # v_meas,
        np.full_like(voc_fit, Cp),
        np.full_like(voc_fit, Cn),
        np.full_like(voc_fit, Cli),
    ])
    
    # Capacity-model features: predicted capacity axis + broadcast parameters.
    X_i_q = np.column_stack([
        predict_q,
        # q_meas,
        # v_meas,
        np.full_like(voc_fit, Cp),
        np.full_like(voc_fit, Cn),
        np.full_like(voc_fit, Cli),
    ])
    
    X_features.append(X_i.flatten())  # length 4 * n_points
    X_features_q.append(X_i_q.flatten())  # length 4 * n_points
    Y_targets.append((voc_real - voc_fit))  # voltage residual, length n_points
    Y_targets_q.append((q_real - predict_q))  # capacity residual, length n_points
    

X_all = np.stack(X_features)  # shape: (n_samples, 4 * n_points)
Y_all = np.stack(Y_targets)   # shape: (n_samples, n_points)

X_all_q = np.stack(X_features_q)  # shape: (n_samples, 4 * n_points)
Y_all_q = np.stack(Y_targets_q)   # shape: (n_samples, n_points)

# Fitted on all valid samples (no split); reused in the C/5 -> C/40 cell.
scaler_X_c5_c40 = StandardScaler().fit(X_all)
scaler_Y_c5_c40 = StandardScaler().fit(Y_all)

scaler_X_c5_c40_q = StandardScaler().fit(X_all_q)
scaler_Y_c5_c40_q = StandardScaler().fit(Y_all_q)

#%%
with open("tesla_extract_data_pso.pkl", "rb") as f:
    data = pickle.load(f)

# Pull every per-cell field out of the pickled dict in a single unpacking step
# (same keys, same order as they are consumed below).
(all_Cp_opt, all_Cn_opt, all_x0_opt, all_y0_opt, all_Cq, all_OCV_fit,
 all_cell_cap, all_cell_ocv, all_cell_vmea, all_cells, all_efcs) = (
    data[field] for field in (
        "all_Cp_opt", "all_Cn_opt", "all_x0_opt", "all_y0_opt", "all_Cq",
        "all_OCV_fit", "all_cell_cap", "all_cell_ocv", "all_cell_vmea",
        "all_cells", "all_efcs",
    )
)

# Per-cell series Cli = Cp * y0 + x0 * Cn, element-wise over the fitted RPTs.
all_Cli = list(map(
    lambda cp, y0, x0, cn: np.array(cp) * np.array(y0) + np.array(x0) * np.array(cn),
    all_Cp_opt, all_y0_opt, all_x0_opt, all_Cn_opt,
))

# Residual model trained on the original C/5 dataset.
final_model = joblib.load('saved_fittings/' + 'final_model_C5_train1.pkl')

# Evaluate the C/5 residual model on one test cell (test_batteries_indices[20]):
# residual scatter/histogram, then measured vs. fitted vs. compensated dV/dQ
# curves with the first peak (Q < 1.5 Ah) tracked per RPT.
for i_idx, i in enumerate([test_batteries_indices[20]]):  #227, 38, 26, 100, 235 [7,216]
    efcs = all_efcs[i]
    print('results for cell:',all_cells[i][0])
    # C/5 fitted electrode parameters, one entry per RPT.
    cell_efc = all_efcs[i]
    Cq_5 = np.array(all_Cq[i])
    Cp_5 = np.array(all_Cp_opt[i])
    Cn_5 = np.array(all_Cn_opt[i])
    Cli_5 = np.array(all_Cli[i])
    cell_ocv = all_cell_ocv[i]
    cell_Vreal = cell_ocv[:,:,0]*4.2
    cell_Qreal = cell_ocv[:,:,1]
    cell_vmea = all_cell_vmea[i]
    cell_Vmea = cell_vmea[:,:,0]*4.2
    cell_Qmea = cell_vmea[:,:,1]

    fit_OCV = np.array(all_OCV_fit[i])

    # The model predicts the gap between the measured curve and the fit.
    real_OCV = cell_Vmea
    X_features = []
    Y_targets = []
    for j in range(len(Cq_5)):
        voc_fit = fit_OCV[j,:]
        voc_real = real_OCV[j,:]
        q_meas = cell_Qmea[j,:]
        v_meas = cell_Vmea[j,:]

        cp = Cp_5[j]
        cn = Cn_5[j]
        cq = Cq_5[j]
        cli = Cli_5[j]
        # Per-point features: fitted voltage + broadcast scalar parameters.
        X_i = np.column_stack([
            voc_fit,
            # q_meas,
            # v_meas,
            np.full_like(voc_fit, cp),
            np.full_like(voc_fit, cn),
            np.full_like(voc_fit, cli),
        ])

        X_features.append(X_i.flatten())  # length 4 * n_points
        Y_targets.append((voc_real - voc_fit))  # length n_points

    X_all = np.stack(X_features)  # shape: (n_rpts, 4 * n_points)
    Y_all = np.stack(Y_targets)   # shape: (n_rpts, n_points)

    Y_test_pred_std = final_model.predict(scaler_X.transform(X_all))
    Y_test_pred = scaler_Y.inverse_transform(Y_test_pred_std)

    test_rmse = np.sqrt(mean_squared_error(Y_all, Y_test_pred))
    print(f"Test RMSE: {test_rmse:.4f}")

    # Flattened residuals in mV, shared by the scatter, reference line and histogram.
    y_real_mv = Y_all.reshape(-1) * 1000
    y_pred_mv = Y_test_pred.reshape(-1) * 1000

    cmap = mcolors.LinearSegmentedColormap.from_list('custom_cmap',['#E59693','#0073B1'])
    orig_color_values = np.abs(y_real_mv - y_pred_mv)
    norm = Normalize(vmin=orig_color_values.min(), vmax=orig_color_values.max())
    color_values = norm(orig_color_values)
    plt.figure(figsize=(5.5 / 2.54, 6 / 2.54), dpi=600)
    plt.ion()
    plt.rcParams['xtick.direction'] = 'in'
    plt.rcParams['ytick.direction'] = 'in'
    plt.tick_params(top='on', right='on', which='both')
    plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
    scatter1 = plt.scatter(y_real_mv, y_pred_mv,
                          c=color_values, alpha=0.9, cmap='coolwarm', marker='o', linewidth=0.0, s=10, edgecolors=None)
    # y = x reference line through its two end points. Passing every unsorted
    # residual retraced the diagonal back and forth thousands of times (broken
    # dash pattern, very slow to render at dpi=600).
    identity_lims = [y_real_mv.min(), y_real_mv.max()]
    plt.plot(identity_lims, identity_lims, '--', color='grey', linewidth=1)
    plt.xlabel('Real values [mV]')
    plt.ylabel('Predictions [mV]')
    cbar = plt.colorbar()
    cbar.set_label('Absolute error')
    plt.tick_params(bottom=False, left=False)
    # Colour values were normalised above, so label the bar in original units.
    ticks = np.linspace(orig_color_values.min(), orig_color_values.max(), num=3)
    tick_labels = ["{:.2f}".format(value) for value in ticks]
    cbar.set_ticks(norm(ticks))
    cbar.set_ticklabels(tick_labels)

    # Inset histogram of the residual prediction error.
    ax_hist = inset_axes(
        plt.gca(),
        width="40%", height="30%",
        bbox_to_anchor=(0.44, 0.19, 0.62, 0.6),  # x0, y0, width, height
        bbox_transform=plt.gcf().transFigure,
        loc='lower left'
    )
    ax_hist.hist(
        (y_real_mv - y_pred_mv),
        bins=20, color='gray', edgecolor='black',linewidth=0.5
    )

    ax_hist.set_xlabel('Error [mV]', fontsize=6, labelpad=1)
    ax_hist.set_ylabel('Count', fontsize=6, labelpad=1)
    ax_hist.tick_params(axis='both', which='major', labelsize=6)
    plt.show()

    # First-peak height/position per RPT (NaN when no peak is found).
    peak_vals_orig = []
    peak_vals_fit = []
    peak_vals_comp = []

    peak_pos_orig = []
    peak_pos_fit = []
    peak_pos_comp = []
    fig, axs = plt.subplots(1, 1, figsize=(10/ 2.54, 6 / 2.54), dpi=600)
    plt.ion()
    plt.rcParams['xtick.direction'] = 'in'
    plt.rcParams['ytick.direction'] = 'in'
    plt.tick_params(top='on', right='on', which='both')
    plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
    for j in range(len(Y_test_pred)):
        # Capacity axis in Ah (norminal_c is defined elsewhere in the file).
        Q = cell_Qmea[j, :] * norminal_c
        dv_dq_orig = gradient( cell_Vmea[j,:], Q )
        dv_dq_fit = gradient( fit_OCV[j,:], Q )
        # Compensated curve = physics fit + predicted residual.
        dv_dq_compensate = gradient( fit_OCV[j,:]+Y_test_pred[j,:], Q )
        dv_dq_orig = savgol_filter(dv_dq_orig, window_length=11, polyorder=1)
        dv_dq_fit = savgol_filter(dv_dq_fit, window_length=11, polyorder=1)
        dv_dq_compensate = savgol_filter(dv_dq_compensate, window_length=11, polyorder=1)

        # Later RPTs are drawn in progressively darker shades.
        color_scale = 1 / (0.06 * j + 1)
        plt.plot(Q, -dv_dq_orig, color=np.array(mcolors.to_rgb(colors[0])) * color_scale, label='Real' if j == 0 else None)
        plt.plot(Q, -dv_dq_fit, '--', color=np.array(mcolors.to_rgb(colors[5])) * color_scale, label='Fitted' if j == 0 else None)
        plt.plot(Q, -dv_dq_compensate, '--', color=np.array(mcolors.to_rgb(colors[1])) * color_scale, label='Compensated' if j == 0 else None)

        # Track the first peak in the low-capacity region only.
        region_mask = Q < 1.5
        Q_focus = Q[region_mask]

        real_focus = -dv_dq_orig[region_mask]
        fit_focus = -dv_dq_fit[region_mask]
        comp_focus = -dv_dq_compensate[region_mask]

        peaks_real, _ = find_peaks(real_focus, prominence=0.02)
        peaks_fit, _ = find_peaks(fit_focus, prominence=0.02)
        peaks_comp, _ = find_peaks(comp_focus, prominence=0.02)

        if len(peaks_real) > 0:
            peak_vals_orig.append(real_focus[peaks_real[0]])
            peak_pos_orig.append(Q_focus[peaks_real[0]])
        else:
            peak_vals_orig.append(np.nan)
            peak_pos_orig.append(np.nan)

        if len(peaks_fit) > 0:
            peak_vals_fit.append(fit_focus[peaks_fit[0]])
            peak_pos_fit.append(Q_focus[peaks_fit[0]])
        else:
            peak_vals_fit.append(np.nan)
            peak_pos_fit.append(np.nan)

        if len(peaks_comp) > 0:
            peak_vals_comp.append(comp_focus[peaks_comp[0]])
            peak_pos_comp.append(Q_focus[peaks_comp[0]])
        else:
            peak_vals_comp.append(np.nan)
            peak_pos_comp.append(np.nan)

    plt.ylim([0,1])
    plt.ylabel('dV/dQ [V/Ah]')
    plt.xlabel('Q [Ah]')
    axs.legend(loc='best',
              handletextpad=0.1,
              labelspacing=0.05,
              bbox_to_anchor=(0.35, 0.55),
              frameon=False)
    plt.show()
    

# Evolution of the first dV/dQ peak (height, then position) over EFC for the
# cell analysed in the previous loop: real vs. fitted vs. compensated.
# The figure is created first and the two panels are added explicitly. The
# previous plt.subplots() call also created a full-size Axes that
# plt.subplot(211) no longer deletes in recent Matplotlib, which left an empty
# frame behind both panels.
fig2 = plt.figure(figsize=(5/2.54, 6.5/2.54), dpi=600)
ax2 = plt.subplot(211)
plt.plot(cell_efc, peak_vals_orig, '-o', markersize=5, color=colors[0],label='Real Peak')
plt.plot(cell_efc, peak_vals_fit, '--s', markersize=5, color=colors[5],label='Fitted Peak')
plt.plot(cell_efc, peak_vals_comp, '--^', markersize=5, color=colors[1],label='Compensated Peak')
plt.ylabel('Peak')
# NaN marks RPTs without a detected peak. Only set padded limits when finite
# values exist; an all-NaN pool would make plt.ylim raise.
peak_val_pool = np.asarray(peak_vals_orig + peak_vals_fit + peak_vals_comp, dtype=float)
peak_val_pool = peak_val_pool[np.isfinite(peak_val_pool)]
if peak_val_pool.size:
    plt.ylim(peak_val_pool.min() * 0.9, peak_val_pool.max() * 1.1)
plt.gca().set_xticklabels([])
plt.grid(True)

plt.subplot(212)
plt.plot(cell_efc, peak_pos_orig, '-o', markersize=5, color=colors[0],label='Real Peak')
plt.plot(cell_efc, peak_pos_fit, '--s', markersize=5, color=colors[5],label='Fitted Peak')
plt.plot(cell_efc, peak_pos_comp, '--^', markersize=5, color=colors[1],label='Compensated Peak')
plt.ylabel('Position')
plt.xlabel('EFC')
plt.grid(True)
peak_pos_pool = np.asarray(peak_pos_orig + peak_pos_fit + peak_pos_comp, dtype=float)
peak_pos_pool = peak_pos_pool[np.isfinite(peak_pos_pool)]
if peak_pos_pool.size:
    plt.ylim(peak_pos_pool.min() * 0.9, peak_pos_pool.max() * 1.1)
plt.subplots_adjust(hspace=0.25)
plt.show()

#%%

# --- Retrain the C/5 residual model on the training cells -------------------
X_features = []
Y_targets = []
for i_idx, i in enumerate(train_batteries_indices):  #227, 38, 26, 100, 235 [7,216]
    efcs = all_efcs[i]
    print('load data for cell:',all_cells[i][0])
    # C/5 fitted electrode parameters, one entry per RPT.
    Cq_5 = np.array(all_Cq[i]) 
    Cp_5 = np.array(all_Cp_opt[i]) 
    Cn_5 = np.array(all_Cn_opt[i])
    Cli_5 = np.array(all_Cli[i])
    cell_ocv = all_cell_ocv[i]
    cell_Vreal = cell_ocv[:,:,0]*4.2
    cell_Qreal = cell_ocv[:,:,1]
    cell_vmea = all_cell_vmea[i]
    cell_Vmea = cell_vmea[:,:,0]*4.2
    cell_Qmea = cell_vmea[:,:,1]
    
    fit_OCV = np.array(all_OCV_fit[i])
    
    # Target = measured curve minus fitted curve.
    real_OCV = cell_Vmea
    
    for j in range(len(Cq_5)):
        # Curves and parameters for RPT j.
        voc_fit = fit_OCV[j,:]
        voc_real = real_OCV[j,:]
        q_meas = cell_Qmea[j,:]
        v_meas = cell_Vmea[j,:]
        
        cp = Cp_5[j]
        cn = Cn_5[j]
        cq = Cq_5[j]
        cli = Cli_5[j]
        # Per-point features: fitted voltage + broadcast scalar parameters.
        X_i = np.column_stack([
            voc_fit,
            # q_meas,
            # v_meas,
            np.full_like(voc_fit, cp),
            np.full_like(voc_fit, cn),
            np.full_like(voc_fit, cli),
        ])
        
        # Flatten the (n_points, 4) feature matrix row-major.
        X_features.append(X_i.flatten())  # length 4 * n_points
        Y_targets.append((voc_real - voc_fit))  # length n_points
        
        # plt.figure(figsize=(10 / 2.54,6 / 2.54), dpi=600)
        # plt.ion()
        # plt.rcParams['xtick.direction'] = 'in'
        # plt.rcParams['ytick.direction'] = 'in'
        # plt.tick_params(top='on', right='on', which='both')
        # plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
        # plt.plot(q_meas,voc_fit)
        # plt.plot(q_meas,voc_real)
        # plt.show()
        
        
model_name = 'tesla_retrain_model_residual_C5_1.pkl'
X_all = np.stack(X_features)  # shape: (n_samples, 4 * n_points)
Y_all = np.stack(Y_targets)   # shape: (n_samples, n_points)

# All training-cell samples are used (no validation split).
# NOTE(review): scaler_X / scaler_Y were fitted on a different dataset in an
# earlier cell and are reused here without refitting -- confirm intended.
# X_train, X_val, Y_train, Y_val = train_test_split(X_all, Y_all, test_size=0.2, random_state=42)
X_train_std = scaler_X.transform(X_all)
Y_train_std = scaler_Y.transform(Y_all)

# Refits, in place, the model object loaded from final_model_C5_train1.pkl,
# then saves it under model_name and reloads it.
final_model.fit(X_train_std, Y_train_std)
joblib.dump(final_model, 'saved_fittings/'+model_name)
final_model = joblib.load('saved_fittings/'+model_name)

#%%

# --- Residual features/targets for every RPT of every test cell -------------
X_features = []
Y_targets = []
for i_idx, i in enumerate(test_batteries_indices):  #227, 38, 26, 100, 235 [7,216]
    efcs = all_efcs[i]
    print('results for cell:',all_cells[i][0])
    # C/5 fitted electrode parameters, one entry per RPT.
    Cq_5 = np.array(all_Cq[i]) 
    Cp_5 = np.array(all_Cp_opt[i]) 
    Cn_5 = np.array(all_Cn_opt[i])
    Cli_5 = np.array(all_Cli[i])
    cell_ocv = all_cell_ocv[i]
    cell_Vreal = cell_ocv[:,:,0]*4.2
    cell_Qreal = cell_ocv[:,:,1]
    cell_vmea = all_cell_vmea[i]
    cell_Vmea = cell_vmea[:,:,0]*4.2
    cell_Qmea = cell_vmea[:,:,1]
    
    fit_OCV = np.array(all_OCV_fit[i])
    
    # Target = measured curve minus fitted curve.
    real_OCV = cell_Vmea
    
    for j in range(len(Cq_5)):
        # Curves and parameters for RPT j.
        voc_fit = fit_OCV[j,:]
        voc_real = real_OCV[j,:]
        q_meas = cell_Qmea[j,:]
        v_meas = cell_Vmea[j,:]
        
        cp = Cp_5[j]
        cn = Cn_5[j]
        cq = Cq_5[j]
        cli = Cli_5[j]
        # Per-point features: fitted voltage + broadcast scalar parameters.
        X_i = np.column_stack([
            voc_fit,
            # q_meas,
            # v_meas,
            np.full_like(voc_fit, cp),
            np.full_like(voc_fit, cn),
            np.full_like(voc_fit, cli),
        ])
        
        # Flatten the (n_points, 4) feature matrix row-major.
        X_features.append(X_i.flatten())  # length 4 * n_points
        Y_targets.append((voc_real - voc_fit))  # length n_points
        
        # plt.figure(figsize=(10 / 2.54,6 / 2.54), dpi=600)
        # plt.ion()
        # plt.rcParams['xtick.direction'] = 'in'
        # plt.rcParams['ytick.direction'] = 'in'
        # plt.tick_params(top='on', right='on', which='both')
        # plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
        # plt.plot(q_meas,voc_fit)
        # plt.plot(q_meas,voc_real)
        # plt.show()
        
        
X_all = np.stack(X_features)  # shape: (n_samples, 4 * n_points)
Y_all = np.stack(Y_targets)   # shape: (n_samples, n_points)

# Predict the residual for every test-cell RPT (features built just above),
# report the RMSE and draw a predicted-vs-real scatter with an error histogram.
Y_test_pred_std = final_model.predict(scaler_X.transform(X_all))
Y_test_pred = scaler_Y.inverse_transform(Y_test_pred_std)

test_rmse = np.sqrt(mean_squared_error(Y_all, Y_test_pred))
print(f"Test RMSE: {test_rmse:.4f}")

# Flattened residuals in mV, computed once and shared by all plots below.
y_real_mv = Y_all.reshape(-1) * 1000
y_pred_mv = Y_test_pred.reshape(-1) * 1000

cmap = mcolors.LinearSegmentedColormap.from_list('custom_cmap',['#E59693','#0073B1'])
orig_color_values = np.abs(y_real_mv - y_pred_mv)
norm = Normalize(vmin=orig_color_values.min(), vmax=orig_color_values.max())
color_values = norm(orig_color_values)
plt.figure(figsize=(5.5 / 2.54,6 / 2.54), dpi=600)
plt.ion()
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
plt.tick_params(top='on', right='on', which='both')
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
scatter1 = plt.scatter(y_real_mv, y_pred_mv,
                      c=color_values, alpha=0.8, cmap='coolwarm', marker='o', linewidth=0.0, s=10, edgecolors=None)
# y = x reference line. ndarray.min/max run in C. The builtin min()/max()
# walked the whole flattened array in Python four times, and their result
# depends on element order if a NaN is present.
identity_lims = [y_real_mv.min(), y_real_mv.max()]
plt.plot(identity_lims, identity_lims, '--', color='grey', linewidth=1)
plt.xlabel('Real values [mV]')
plt.ylabel('Predictions [mV]')
cbar = plt.colorbar()
cbar.set_label('Absolute error')
plt.tick_params(bottom=False, left=False)
# Colour values were normalised above, so label the bar in original units.
ticks = np.linspace(orig_color_values.min(), orig_color_values.max(), num=3)
tick_labels = ["{:.2f}".format(value) for value in ticks]
cbar.set_ticks(norm(ticks))
cbar.set_ticklabels(tick_labels)

# Inset histogram of the residual prediction error.
ax_hist = inset_axes(
    plt.gca(),
    width="40%", height="30%",
    bbox_to_anchor=(0.44, 0.19, 0.62, 0.6),  # x0, y0, width, height
    bbox_transform=plt.gcf().transFigure,
    loc='lower left'
)
ax_hist.hist(
    (y_real_mv - y_pred_mv),
    bins=20, color='gray', edgecolor='black',linewidth=0.5
)

ax_hist.set_xlabel('Error [mV]', fontsize=6, labelpad=1)
ax_hist.set_ylabel('Count', fontsize=6, labelpad=1)
ax_hist.tick_params(axis='both', which='major', labelsize=6)
ax_hist.ticklabel_format(axis='y', style='sci', scilimits=(0, 0))
ax_hist.yaxis.offsetText.set_fontsize(6)
plt.show()


#%%
# Compare the original C/5 residual model with the Tesla-retrained one on the
# test cells. Metrics are in mV on the full OCV curve (fit + residual). The
# fit-only baseline is printed alongside each model.
model_names = ['final_model_C5_train1.pkl','tesla_retrain_model_residual_C5_1.pkl']

# Features and targets do not depend on the model being evaluated, so they
# are built once here instead of once per model.
X_features = []
Y_targets = []
for i_idx, i in enumerate(test_batteries_indices):  #227, 38, 26, 100, 235 [7,216]
    efcs = all_efcs[i]
    # print('results for cell:',all_cells[i][0])
    # C/5 fitted electrode parameters, one entry per RPT.
    Cq_5 = np.array(all_Cq[i])
    Cp_5 = np.array(all_Cp_opt[i])
    Cn_5 = np.array(all_Cn_opt[i])
    Cli_5 = np.array(all_Cli[i])
    cell_ocv = all_cell_ocv[i]
    cell_Vreal = cell_ocv[:,:,0]*4.2
    cell_Qreal = cell_ocv[:,:,1]
    cell_vmea = all_cell_vmea[i]
    cell_Vmea = cell_vmea[:,:,0]*4.2
    cell_Qmea = cell_vmea[:,:,1]

    fit_OCV = np.array(all_OCV_fit[i])

    real_OCV = cell_Vmea

    for j in range(len(Cq_5)):
        voc_fit = fit_OCV[j,:]
        voc_real = real_OCV[j,:]
        q_meas = cell_Qmea[j,:]
        v_meas = cell_Vmea[j,:]

        cp = Cp_5[j]
        cn = Cn_5[j]
        cq = Cq_5[j]
        cli = Cli_5[j]
        # Per-point features: fitted voltage + broadcast scalar parameters.
        X_i = np.column_stack([
            voc_fit,
            # q_meas,
            # v_meas,
            np.full_like(voc_fit, cp),
            np.full_like(voc_fit, cn),
            np.full_like(voc_fit, cli),
        ])

        X_features.append(X_i.flatten())  # length 4 * n_points, row-major
        Y_targets.append((voc_real - voc_fit))  # length n_points

X_all = np.stack(X_features)  # shape: (n_samples, 4 * n_points)
Y_all = np.stack(Y_targets)   # shape: (n_samples, n_points)

# Column 0 of every interleaved 4-feature group is the fitted voltage.
fit_ocv = X_all[:, 0::4]
orig_ocv = fit_ocv + Y_all

# Baseline: the physics fit alone (also independent of the model).
fit_rmse = np.sqrt(mean_squared_error(fit_ocv.reshape(-1)*1000, orig_ocv.reshape(-1)*1000))
fit_maae = np.max(abs(orig_ocv.reshape(-1)*1000- fit_ocv.reshape(-1)*1000))
fit_mae = mean_absolute_error(fit_ocv.reshape(-1)*1000, orig_ocv.reshape(-1)*1000)
fit_r2 = r2_score(orig_ocv.reshape(-1), fit_ocv.reshape(-1))

for model_name in model_names:
    print('Test results for:', model_name)
    final_model = joblib.load('saved_fittings/'+model_name)

    Y_test_pred_std = final_model.predict(scaler_X.transform(X_all))
    Y_test_pred = scaler_Y.inverse_transform(Y_test_pred_std)

    # Compensated curve = physics fit + predicted residual.
    pre_ocv = fit_ocv + Y_test_pred

    test_rmse = np.sqrt(mean_squared_error(pre_ocv.reshape(-1)*1000, orig_ocv.reshape(-1)*1000))
    test_maae = np.max(abs(orig_ocv.reshape(-1)*1000- pre_ocv.reshape(-1)*1000))
    test_mae = mean_absolute_error(pre_ocv.reshape(-1)*1000, orig_ocv.reshape(-1)*1000)
    test_r2 = r2_score(orig_ocv.reshape(-1), pre_ocv.reshape(-1))

    # from dtaidistance import dtw
    # test_dtw = dtw.distance(pre_ocv.reshape(-1), orig_ocv.reshape(-1))
    # fit_dtw = dtw.distance(fit_ocv.reshape(-1), orig_ocv.reshape(-1))
    # print(test_dtw,fit_dtw)
    print(f"Test RMSE: {test_rmse:.4f}", f"Test MAE: {test_mae:.4f}", f"Test MaxAE: {test_maae:.4f}", f"Test R2: {test_r2:.4f}")
    print(f"Fit RMSE: {fit_rmse:.4f}", f"Fit MAE: {fit_mae:.4f}", f"Fit MaxAE: {fit_maae:.4f}", f"Fit R2: {fit_r2:.4f}")


#%% compensation from C/5 to C/40
with open("tesla_extract_data_pso.pkl", "rb") as f:
    data = pickle.load(f)

# Electrode-fit parameters per cell.
all_Cp_opt, all_Cn_opt = data["all_Cp_opt"], data["all_Cn_opt"]
all_x0_opt, all_y0_opt = data["all_x0_opt"], data["all_y0_opt"]
# Capacity series and voltage/capacity curves per cell.
all_Cq, all_OCV_fit = data["all_Cq"], data["all_OCV_fit"]
all_cell_cap = data["all_cell_cap"]
all_cell_ocv, all_cell_vmea = data["all_cell_ocv"], data["all_cell_vmea"]
# Cell identifiers and EFC values.
all_cells, all_efcs = data["all_cells"], data["all_efcs"]

# Per-cell series Cli = Cp * y0 + x0 * Cn (params = (Cp, y0, x0, Cn)).
all_Cli = [
    np.array(params[0]) * np.array(params[1]) + np.array(params[2]) * np.array(params[3])
    for params in zip(all_Cp_opt, all_y0_opt, all_x0_opt, all_Cn_opt)
]

# C/40 electrode-state predictions (defined elsewhere in the file).
all_predictions = all_predictions_test
# NOTE(review): model_name names a "_test_" file while both models below are
# loaded from "_train_" files; model_name is not used in this cell.
model_name = 'final_model_C40_test_electrode_prediction_v1.pkl'
final_model_v = joblib.load('saved_fittings/' + 'final_model_C40_train_electrode_prediction_v1.pkl')
final_model_q = joblib.load('saved_fittings/' + 'final_model_C40_train_electrode_prediction_q1.pkl')

# --- C/5 -> C/40: reconstruct OCV/capacity from predicted C/40 electrode
# states, then predict voltage and capacity residuals for every cell ---------
Y_test_all=[]
Q_test_all = []
FitV_test_all = []
EFC_test_all = []
FitQ_test_all = []
for i_idx in range(len(all_Cp_opt)):  #227, 38, 26, 100, 235 [7,216]
    X_features = []
    X_features_q = []
    Y_targets = []
    efcs = all_efcs[i_idx]
    EFC_test_all.append(efcs)
    # print('results for cell:',all_cells[i][0])
    # C/5 fitted parameters (not used as features below).
    Cq_5 = np.array(all_Cq[i_idx]) 
    Cp_5 = np.array(all_Cp_opt[i_idx]) 
    Cn_5 = np.array(all_Cn_opt[i_idx])
    Cli_5 = np.array(all_Cli[i_idx])
    
    # C/40 predicted electrode states: one row per RPT, columns [Cq, Cp, Cn, Cli].
    Y_pred = all_predictions[i_idx]
    Cq_40 = Y_pred[:, 0]
    Cp_40 = Y_pred[:, 1] 
    Cn_40 = Y_pred[:, 2] 
    Cli_40 = Y_pred[:, 3]
    # x0_40 = Cli_40 / Cn_40
    # y0_40 = np.zeros_like(x0_40)
    
    cell_ocv = all_cell_ocv[i_idx]
    cell_Vreal = cell_ocv[:,:,0]*4.2
    cell_Qreal = cell_ocv[:,:,1]
    cell_vmea = all_cell_vmea[i_idx]
    cell_Vmea = cell_vmea[:,:,0]*4.2
    cell_Qmea = cell_vmea[:,:,1]
    
    # Reconstructed capacity axis and voltage curve for every RPT, same shape
    # as the measured voltage array.
    # fit_OCV = np.array(all_OCV_fit[i])
    predict_Q = np.zeros_like(cell_Vmea)
    # real_OCV = cell_Vmea
    predict_V = np.zeros_like(cell_Vmea)
    for j in range(len(Cq_5)):
        # q_meas / v_meas are read but not used as features.
        # voc_fit = fit_OCV[j,:]
        # voc_real = real_OCV[j,:]
        q_meas = cell_Qmea[j,:]
        v_meas = cell_Vmea[j,:]
        
        # cp = Cp_5[j]
        # cn = Cn_5[j]
        # cq = Cq_5[j]
        # cli = Cli_5[j]
        
        # Predicted states divided by nominal_capacity (defined elsewhere).
        cp_40 = Cp_40[j]/nominal_capacity
        cn_40 = Cn_40[j]/nominal_capacity
        cq_40 = Cq_40[j]/nominal_capacity
        cli_40 = Cli_40[j]/nominal_capacity
        # Initial stoichiometries: negative from Cli / Cn, positive fixed at 0.
        x0_40 = cli_40/cn_40
        y0_40 = 0
        measured_Q = np.linspace(0, cq_40, 1000)
        predict_Q[j,:]=measured_Q
        # Electrode stoichiometries along the capacity axis.
        SOC_p = y0_40 + measured_Q / cp_40
        SOC_n = x0_40 - measured_Q / cn_40
    
        # Electrode potentials from OCP_p_40 / OCP_n_40 (defined elsewhere);
        # the full-cell voltage is their difference.
        Up = OCP_p_40(SOC_p)
        Un = OCP_n_40(SOC_n)
        voc_fit = Up - Un 
        predict_V[j,:]=voc_fit
        # predict_q = np.linspace(0, cq_40, 1000)
        # predict_Q[j,:]=predict_q
        # Voltage-model features: reconstructed voltage + broadcast states.
        X_i = np.column_stack([
            voc_fit,
            # q_meas,
            # v_meas,
            np.full_like(voc_fit, cp_40),
            np.full_like(voc_fit, cn_40),
            np.full_like(voc_fit, cli_40),
        ])
        
        # Capacity-model features: reconstructed capacity + broadcast states.
        X_i_q = np.column_stack([
            measured_Q,
            # q_meas,
            # v_meas,
            np.full_like(voc_fit, cp_40),
            np.full_like(voc_fit, cn_40),
            np.full_like(voc_fit, cli_40),
        ])
        
        # Flatten the (1000, 4) feature matrices row-major.
        X_features.append(X_i.flatten())  # length 4000
        X_features_q.append(X_i_q.flatten())  # length 4000

    X_all = np.stack(X_features)  # shape: (n_rpts, 4000)
    X_all_q = np.stack(X_features_q)  # shape: (n_rpts, 4000)
    
    # Predicted voltage residual, back in original units.
    Y_test_pred_std = final_model_v.predict(scaler_X_c5_c40.transform(X_all))
    Y_test_pred = scaler_Y_c5_c40.inverse_transform(Y_test_pred_std)
    
    # Predicted capacity residual, back in original units.
    Y_test_pred_std_q = final_model_q.predict(scaler_X_c5_c40_q.transform(X_all_q))
    Y_test_pred_q = scaler_Y_c5_c40_q.inverse_transform(Y_test_pred_std_q)
    
    Y_test_all.append(Y_test_pred)
    Q_test_all.append(Y_test_pred_q)
    
    FitV_test_all.append(predict_V)
    FitQ_test_all.append(predict_Q)
    
#%%
# Visualise the C/5 -> C/40 results for one cell (index 7). Three figures
# each draw one line per RPT: reconstructed, compensated and measured -dV/dQ.
# All three plot -dV/dQ with ylim [0, 1], so the y label is 'dV/dQ [V/Ah]'
# (matching the earlier dV/dQ figure), not 'OCV [V]'. Switch the label back
# if the commented-out OCV-vs-Q curves are re-enabled.
i = 7
efcs = EFC_test_all[i]
Compen_ocv = Y_test_all[i]+FitV_test_all[i]
Reconst_ocv = FitV_test_all[i]
Predict_q = FitQ_test_all[i]
Compen_q = Q_test_all[i]+FitQ_test_all[i]
cell_ocv = all_cell_ocv[i]
Real_q = cell_ocv[:,:,1]
Real_ocv = cell_ocv[:,:,0]*4.2

# Figure 1: reconstructed C/40 dV/dQ.
plt.figure(figsize=(10 / 2.54,6 / 2.54), dpi=600)
plt.ion()
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
plt.tick_params(top='on', right='on', which='both')
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)

for idx in range(len(efcs)):
    predict_Q = Predict_q[idx]
    OCV_recon = Reconst_ocv[idx]
    OCV_compen = Compen_ocv[idx]
    Q_compen = Compen_q[idx]
    dv_dq_compensate = gradient( OCV_compen, Q_compen*nominal_capacity )
    dv_dq_recon = gradient( OCV_recon, predict_Q*nominal_capacity)
    dv_dq_mea = gradient( Real_ocv[idx], Real_q[idx]*nominal_capacity )

    # plt.plot(Real_q[idx]*nominal_capacity, Real_ocv[idx], alpha=0.8, color=colors[3],label=f'C/5 Measured' if idx==0 else None)
    # plt.plot(predict_Q*nominal_capacity, OCV_recon, alpha=0.8, color=colors[0],label=f'C/40 Reconstructed' if idx==0 else None)
    # plt.plot(predict_Q*nominal_capacity, OCV_compen,'--', color=colors[1], alpha=0.8, label=f'C/40 Compensation' if idx==0 else None)
    # plt.plot(Real_q[idx]*nominal_capacity, -dv_dq_mea, alpha=0.8, color=colors[3],label=f'C/5 Measured' if idx==0 else None)

    plt.plot(predict_Q*nominal_capacity, -dv_dq_recon, alpha=0.8, color=colors[0],label=f'C/40 Reconstructed' if idx==0 else None)
    # plt.plot(predict_Q*nominal_capacity, -dv_dq_compensate,'--', color=colors[1], alpha=0.8, label=f'C/40 Compensation' if idx==0 else None)
plt.xlabel('Q [Ah]')
plt.ylabel('dV/dQ [V/Ah]')
plt.ylim([0,1])
plt.legend(loc='best',
    handletextpad=0.1,
    labelspacing=0.05,
    frameon=False,
    fontsize=10)
plt.show()

# Figure 2: compensated C/40 dV/dQ.
plt.figure(figsize=(10 / 2.54,6 / 2.54), dpi=600)
plt.ion()
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
plt.tick_params(top='on', right='on', which='both')
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
for idx in range(len(efcs)):
    predict_Q = Predict_q[idx]
    OCV_recon = Reconst_ocv[idx]
    OCV_compen = Compen_ocv[idx]
    Q_compen = Compen_q[idx]
    dv_dq_compensate = gradient( OCV_compen, Q_compen*nominal_capacity )
    dv_dq_recon = gradient( OCV_recon, predict_Q*nominal_capacity)
    dv_dq_mea = gradient( Real_ocv[idx], Real_q[idx]*nominal_capacity )

    # plt.plot(Real_q[idx]*nominal_capacity, Real_ocv[idx], alpha=0.8, color=colors[3],label=f'C/5 Measured' if idx==0 else None)
    # plt.plot(predict_Q*nominal_capacity, OCV_recon, alpha=0.8, color=colors[0],label=f'C/40 Reconstructed' if idx==0 else None)
    # plt.plot(predict_Q*nominal_capacity, OCV_compen,'--', color=colors[1], alpha=0.8, label=f'C/40 Compensation' if idx==0 else None)
    # plt.plot(Real_q[idx]*nominal_capacity, -dv_dq_mea, alpha=0.8, color=colors[3],label=f'C/5 Measured' if idx==0 else None)
    # plt.plot(predict_Q*nominal_capacity, -dv_dq_recon, alpha=0.8, color=colors[0],label=f'C/40 Reconstructed' if idx==0 else None)

    # NOTE(review): the derivative is taken w.r.t. Q_compen but plotted
    # against predict_Q -- confirm which capacity axis is intended.
    plt.plot(predict_Q*nominal_capacity, -dv_dq_compensate,'--', color=colors[1], alpha=0.8, label=f'C/40 Compensation' if idx==0 else None)

plt.xlabel('Q [Ah]')
plt.ylabel('dV/dQ [V/Ah]')
plt.ylim([0,1])
plt.legend(loc='best',
    handletextpad=0.1,
    labelspacing=0.05,
    frameon=False,
    fontsize=10)
plt.show()


# Figure 3: dV/dQ of the curves stored in all_cell_ocv for this cell.
plt.figure(figsize=(10 / 2.54,6 / 2.54), dpi=600)
plt.ion()
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
plt.tick_params(top='on', right='on', which='both')
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)

for idx in range(len(efcs)):
    predict_Q = Predict_q[idx]
    OCV_recon = Reconst_ocv[idx]
    OCV_compen = Compen_ocv[idx]
    Q_compen = Compen_q[idx]
    dv_dq_compensate = gradient( OCV_compen, Q_compen*nominal_capacity )
    dv_dq_recon = gradient( OCV_recon, predict_Q*nominal_capacity)
    dv_dq_mea = gradient( Real_ocv[idx], Real_q[idx]*nominal_capacity )

    plt.plot(Real_q[idx]*nominal_capacity, -dv_dq_mea, alpha=0.8, color=colors[5],label=f'C/5 Measured' if idx==0 else None)

    # plt.plot(predict_Q*nominal_capacity, -dv_dq_recon, alpha=0.8, color=colors[0],label=f'C/40 Reconstructed' if idx==0 else None)
    # plt.plot(predict_Q*nominal_capacity, -dv_dq_compensate,'--', color=colors[1], alpha=0.8, label=f'C/40 Compensation' if idx==0 else None)
plt.xlabel('Q [Ah]')
plt.ylabel('dV/dQ [V/Ah]')
plt.ylim([0,1])
plt.legend(loc='best',
    handletextpad=0.1,
    labelspacing=0.05,
    frameon=False,
    fontsize=10)
plt.show()



#%% life time prediction
# Alternate cells between the splits: even positions -> test, odd -> train.
all_indices = np.arange(0, len(all_efcs))
test_batteries_indices1 = np.arange(0, len(all_efcs), 2)
train_batteries_indices1 = np.arange(1, len(all_efcs), 2)
coefs_all = []
for efc_treshold in [0.95,0.9,0.85,0.825, 0.8]:
    print('Results for SOH @:',efc_treshold)
    all_cycle_life = []
    # calculate cycle life
    # plt.figure(figsize=(10 / 2.54,6 / 2.54), dpi=600)
    # plt.ion()
    # plt.rcParams['xtick.direction'] = 'in'
    # plt.rcParams['ytick.direction'] = 'in'
    # plt.tick_params(top='on', right='on', which='both')
    # plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
    
    for idx in range(len(all_efcs)):
        efc_each_cell = all_efcs[idx]
        cap_each_cell = all_Cq[idx]
        # cap_each_cell = cap_each_cell/cap_each_cell[0]
        # plt.plot(efc_each_cell,cap_each_cell)
        interp_efc_q = interp1d(efc_each_cell, cap_each_cell, kind='linear',bounds_error=False, fill_value="extrapolate")
        cap_interp = interp_efc_q(np.arange(efc_each_cell[0],efc_each_cell[-1]+1))
        for i in range(len(cap_interp)):
            if cap_interp[i]<efc_treshold:
                break
        all_cycle_life.append(i)
        print(np.mean(all_cycle_life))
                                  
    # plt.ylabel('SOH')  
    # plt.xlabel('EFC') 
    # plt.show()
    
    #%
    
    
    X_features_cycle = []
    Y_targets_cycle = []
    for i_idx, i in enumerate(train_batteries_indices):  #227, 38, 26, 100, 235 [7,216]
        efcs = all_efcs[i]
        # print('load data for cell:',all_cells[i][0])
        # C/5
        Cq_5 = np.array(all_Cq[i]) 
        Cp_5 = np.array(all_Cp_opt[i]) 
        Cn_5 = np.array(all_Cn_opt[i])
        Cli_5 = np.array(all_Cli[i])
        cell_ocv = all_cell_ocv[i]
        cell_Vreal = cell_ocv[:,:,0]*4.2
        cell_Qreal = cell_ocv[:,:,1]
        cell_vmea = all_cell_vmea[i]
        cell_Vmea = cell_vmea[:,:,0]*4.2
        cell_Qmea = cell_vmea[:,:,1]
        C40_pred = all_predictions_test[i]/norminal_c
        
        fit_OCV = np.array(all_OCV_fit[i])
        
        real_OCV = cell_Vmea
        
        if len(Cq_5)<3 : # 3--1   7--5
            continue
        for j in [1]:#range(len(Cq_5)):
            
            voc_fit = fit_OCV[j,:]
            voc_real = real_OCV[j,:]
            q_meas = cell_Qmea[j,:]
            v_meas = cell_Vmea[j,:]
            
            cp = Cp_5[j]
            cn = Cn_5[j]
            cq = Cq_5[j]
            cli = Cli_5[j]  # 
            
            cq_40 = C40_pred[j, 0]
            cp_40 = C40_pred[j, 1] 
            cn_40 = C40_pred[j, 2] 
            cli_40 = C40_pred[j, 3]
            
            
            # 
            dv_dq_real = gradient( voc_real, q_meas*norminal_c )
            dv_dq_real = savgol_filter(dv_dq_real, window_length=91, polyorder=1)
            dv_dq_fit = gradient( voc_fit, q_meas*norminal_c )
            dv_dq_fit = savgol_filter(dv_dq_fit, window_length=91, polyorder=1)
            dv_dq_real0 = gradient( real_OCV[0,:], cell_Qmea[0,:]*norminal_c )
            dv_dq_real0 = savgol_filter(dv_dq_real0, window_length=91, polyorder=1)
            
            dv_dq_fit0 = gradient( fit_OCV[0,:], cell_Qmea[0,:]*norminal_c )
            dv_dq_fit0 = savgol_filter(dv_dq_fit0, window_length=91, polyorder=1)
            
            
            region_mask = (q_meas*norminal_c) < 3
            Q_focus = Q[region_mask]
            real_focus = -dv_dq_real[region_mask]
            fit_focus = -dv_dq_fit[region_mask]
            real_focus0 = -dv_dq_real0[region_mask]
            fit_focus0 = -dv_dq_fit0[region_mask]
            
            peaks_real, _ = find_peaks(real_focus, prominence=0.01)
            peaks_fit, _ = find_peaks(fit_focus, prominence=0.01)
            peaks_real0, _ = find_peaks(real_focus0, prominence=0.01)
            peaks_fit0, _ = find_peaks(fit_focus0, prominence=0.01)
            
            peaks_real_val= real_focus[peaks_real[0]]
            peaks_fit_val= fit_focus[peaks_fit[0]]
            peaks_real_val0= real_focus0[peaks_real0[0]]
            peaks_fit_val0= fit_focus0[peaks_fit0[0]]
            
            
            
            X_i = np.column_stack([
                
                cq,
                cq-Cq_5[0],
                cp,
                cp-Cp_5[0],
                cn,
                cn-Cn_5[0],
                cli,
                cli-Cli_5[0],
                
                # cq_40,
                # cq_40-C40_pred[0, 0],
                # cp_40,
                # cq_40-C40_pred[0, 1],
                # cn_40,
                # cq_40-C40_pred[0, 2],
                # cli_40,
                # cq_40-C40_pred[0, 3],
                
                np.var(voc_real-voc_fit),
                np.var(voc_real-real_OCV[0,:]),
                np.var(voc_fit-real_OCV[0,:]),
                np.max(abs(voc_real-voc_fit)),
                np.var(dv_dq_real[200:-200]-dv_dq_fit[200:-200]),
                np.var(dv_dq_real0[200:-200]-dv_dq_real[200:-200]),
                np.max(dv_dq_real[200:-200]-dv_dq_fit[200:-200]),
                np.min(dv_dq_real[200:-200]-dv_dq_fit[200:-200]),
                # efcs[j],
                peaks_real[0],
                peaks_fit[0],
                peaks_real_val,
                peaks_fit_val,
                abs(peaks_real_val0-peaks_real_val),
                abs(peaks_real0[0]-peaks_real[0]),
                abs(peaks_fit_val0-peaks_fit_val),
                abs(peaks_fit0[0]-peaks_fit[0]),
                efcs[j],
            ])
            
           #### flatten 
            X_features_cycle.append(X_i.flatten())  # shape: (6000,)
            Y_targets_cycle.append((np.log(all_cycle_life[i])))  # shape: (1000,)
            
    model_name = 'tesla_cycle_life_C5.pkl'
    X_all = np.stack(X_features_cycle)  # shape: (n_samples, 6000)
    Y_all = np.stack(Y_targets_cycle).reshape(-1,1)   # shape: (n_samples, 1000)
    
    # print(X_all.shape)
    print(np.mean(X_all[:,-1]),np.mean(np.exp(Y_all)),np.mean(X_all[:,0]))
    X_train, X_val, Y_train, Y_val = train_test_split(X_all, Y_all, test_size=0.2, random_state=42)
    
    np.random.seed(123)
    kf = KFold(n_splits=5, shuffle=True, random_state=42)
    alphas = np.logspace(-3, 1, 30)
    best_alphas = []
    for train_index, val_index in kf.split(X_all):
        X_train, X_val = X_all[train_index], X_all[val_index]
        Y_train, Y_val = Y_all[train_index], Y_all[val_index]
        
        #
        scaler_X_cycle  = StandardScaler().fit(X_train)
        scaler_Y_cycle  = StandardScaler().fit(Y_train)
        X_train_scaled = scaler_X_cycle.transform(X_train)
        X_val_scaled = scaler_X_cycle.transform(X_val)
        Y_train_scaled = scaler_Y_cycle.transform(Y_train)
        Y_val_scaled = scaler_Y_cycle.transform(Y_val)
    
        best_alpha, best_score = None, float('inf')
        for alpha in alphas:
            model = Lasso(alpha=alpha)
            model.fit(X_train_scaled, Y_train_scaled)
            Y_val_pred = scaler_Y_cycle.inverse_transform(model.predict(X_val_scaled).reshape(-1,1))
            Y_val_true = scaler_Y_cycle.inverse_transform(Y_val_scaled)
            rmse = np.sqrt(((Y_val_pred - Y_val_true) ** 2).mean())
            if rmse < best_score:
                best_alpha = alpha
                best_score = rmse
                joblib.dump(model, 'saved_fittings/'+'validation_model_cycle_life.pkl')
        
        best_alphas.append(best_alpha)
        # print(f"Fold best alpha: {best_alpha:.4f}, RMSE: {best_score:.4f}")
    
    final_alpha = np.mean(best_alphas)
    # print(f"\nFinal averaged alpha from 5 folds: {final_alpha:.4f}")
    
    #% training and validation test
    X_train, X_val, Y_train, Y_val = train_test_split(X_all, Y_all, test_size=0.3, random_state=42)
    scaler_X_cycle  = StandardScaler().fit(X_train)
    scaler_Y_cycle  = StandardScaler().fit(Y_train)
    X_train_scaled = scaler_X_cycle.transform(X_train)
    X_val_scaled = scaler_X_cycle.transform(X_val)
    Y_train_scaled = scaler_Y_cycle.transform(Y_train)
    Y_val_scaled = scaler_Y_cycle.transform(Y_val)
    
    model= joblib.load('saved_fittings/'+'validation_model_cycle_life.pkl')
    Y_val_pred = scaler_Y_cycle.inverse_transform(model.predict(X_val_scaled).reshape(-1,1))
    Y_val_true = scaler_Y_cycle.inverse_transform(Y_val_scaled)
    rmse = np.sqrt(((Y_val_pred - Y_val_true) ** 2).mean())
    
    Y_train_pred = scaler_Y_cycle.inverse_transform(model.predict(X_train_scaled).reshape(-1,1))
    Y_train_true = scaler_Y_cycle.inverse_transform(Y_train_scaled)
    rmse = np.sqrt(((Y_val_pred - Y_val_true) ** 2).mean())
    
    # 
    scaler_X_full_cycle = StandardScaler().fit(X_all)
    scaler_Y_full_cycle  = StandardScaler().fit(Y_all)
    
    X_scaled = scaler_X_full_cycle.transform(X_all)
    Y_scaled = scaler_Y_full_cycle.transform(Y_all)
    
    
    cycle_model = Lasso(alpha=final_alpha)
    cycle_model.fit(X_scaled, Y_scaled)
    joblib.dump(cycle_model, 'saved_fittings/'+model_name)
    cycle_model = joblib.load('saved_fittings/'+model_name)
    
    Y_test_pred_std = cycle_model.predict(X_scaled)
    Y_test_pred = scaler_Y_full_cycle.inverse_transform(Y_test_pred_std.reshape(-1,1))
    
    
    # cycle_model = XGBRegressor(
    #     n_estimators=100,
    #     max_depth=10,
    #     learning_rate=0.12,
    #     subsample=0.86,
    #     colsample_bytree=0.8,
    #     # reg_alpha=0.8,
    #     # reg_lambda=20,
    #     random_state=42
    # )
    # cycle_model.fit(X_scaled, Y_scaled.ravel())  # ravel() 
    # Y_test_pred_std = cycle_model.predict(X_scaled)
    # Y_test_pred = scaler_Y_full_cycle.inverse_transform(Y_test_pred_std.reshape(-1, 1))
    
    # Y_test_pred = np.exp(Y_test_pred)
    # Y_all = np.exp(Y_all)
    # test_rmse = np.sqrt(mean_squared_error(Y_all, Y_test_pred))
    # print(f"Test RMSE: {test_rmse:.4f}")
    
    #%
    X_features_cycle = []
    Y_targets_cycle = []
    for i_idx, i in enumerate(test_batteries_indices):  #227, 38, 26, 100, 235 [7,216]
        efcs = all_efcs[i]
        # print('load data for cell:',all_cells[i][0])
        # C/5
        Cq_5 = np.array(all_Cq[i]) 
        Cp_5 = np.array(all_Cp_opt[i]) 
        Cn_5 = np.array(all_Cn_opt[i])
        Cli_5 = np.array(all_Cli[i])
        cell_ocv = all_cell_ocv[i]
        cell_Vreal = cell_ocv[:,:,0]*4.2
        cell_Qreal = cell_ocv[:,:,1]
        cell_vmea = all_cell_vmea[i]
        cell_Vmea = cell_vmea[:,:,0]*4.2
        cell_Qmea = cell_vmea[:,:,1]
        C40_pred = all_predictions_test[i]/norminal_c
        
        fit_OCV = np.array(all_OCV_fit[i])
        
        real_OCV = cell_Vmea
        
        if len(Cq_5)<3 : # 3--1  7--5
            continue
        for j in [1]:#range(len(Cq_5)):
            
            # 
            voc_fit = fit_OCV[j,:]
            voc_real = real_OCV[j,:]
            q_meas = cell_Qmea[j,:]
            v_meas = cell_Vmea[j,:]
            
            cp = Cp_5[j]
            cn = Cn_5[j]
            cq = Cq_5[j]
            cli = Cli_5[j]  # 
            
            cq_40 = C40_pred[j, 0]
            cp_40 = C40_pred[j, 1] 
            cn_40 = C40_pred[j, 2] 
            cli_40 = C40_pred[j, 3]
            
            
            # 
            dv_dq_real = gradient( voc_real, q_meas*norminal_c )
            dv_dq_real = savgol_filter(dv_dq_real, window_length=91, polyorder=1)
            dv_dq_fit = gradient( voc_fit, q_meas*norminal_c )
            dv_dq_fit = savgol_filter(dv_dq_fit, window_length=91, polyorder=1)
            dv_dq_real0 = gradient( real_OCV[0,:], cell_Qmea[0,:]*norminal_c )
            dv_dq_real0 = savgol_filter(dv_dq_real0, window_length=91, polyorder=1)
            
            dv_dq_fit0 = gradient( fit_OCV[0,:], cell_Qmea[0,:]*norminal_c )
            dv_dq_fit0 = savgol_filter(dv_dq_fit0, window_length=91, polyorder=1)
            
            
            region_mask = (q_meas*norminal_c) < 3
            Q_focus = Q[region_mask]
            real_focus = -dv_dq_real[region_mask]
            fit_focus = -dv_dq_fit[region_mask]
            real_focus0 = -dv_dq_real0[region_mask]
            fit_focus0 = -dv_dq_fit0[region_mask]
            
            peaks_real, _ = find_peaks(real_focus, prominence=0.01)
            peaks_fit, _ = find_peaks(fit_focus, prominence=0.01)
            peaks_real0, _ = find_peaks(real_focus0, prominence=0.01)
            peaks_fit0, _ = find_peaks(fit_focus0, prominence=0.01)
            
            peaks_real_val= real_focus[peaks_real[0]]
            peaks_fit_val= fit_focus[peaks_fit[0]]
            peaks_real_val0= real_focus0[peaks_real0[0]]
            peaks_fit_val0= fit_focus0[peaks_fit0[0]]
            
            
            
            X_i = np.column_stack([
                
                cq,
                cq-Cq_5[0],
                cp,
                cp-Cp_5[0],
                cn,
                cn-Cn_5[0],
                cli,
                cli-Cli_5[0],
                
                # cq_40,
                # cq_40-C40_pred[0, 0],
                # cp_40,
                # cq_40-C40_pred[0, 1],
                # cn_40,
                # cq_40-C40_pred[0, 2],
                # cli_40,
                # cq_40-C40_pred[0, 3],
                
                np.var(voc_real-voc_fit),
                np.var(voc_real-real_OCV[0,:]),
                np.var(voc_fit-real_OCV[0,:]),
                np.max(abs(voc_real-voc_fit)),
                np.var(dv_dq_real[200:-200]-dv_dq_fit[200:-200]),
                np.var(dv_dq_real0[200:-200]-dv_dq_real[200:-200]),
                np.max(dv_dq_real[200:-200]-dv_dq_fit[200:-200]),
                np.min(dv_dq_real[200:-200]-dv_dq_fit[200:-200]),
                # efcs[j],
                peaks_real[0],
                peaks_fit[0],
                peaks_real_val,
                peaks_fit_val,
                abs(peaks_real_val0-peaks_real_val),
                abs(peaks_real0[0]-peaks_real[0]),
                abs(peaks_fit_val0-peaks_fit_val),
                abs(peaks_fit0[0]-peaks_fit[0]),
                efcs[j],
            ])
            
            
           #### flatten 
            X_features_cycle.append(X_i.flatten())  # shape: (6000,)
            Y_targets_cycle.append((np.log(all_cycle_life[i])))  # shape: (1000,)
            
            
    X_test_all = np.stack(X_features_cycle)  # shape: (n_samples, 6000)
    Y_test_all = np.stack(Y_targets_cycle).reshape(-1,1)   # shape: (n_samples, 1000)
    X_scaled = scaler_X_full_cycle.transform(X_test_all)
    Y_scaled = scaler_Y_full_cycle.transform(Y_test_all)
    # print(X_test_all.shape)
    print(np.mean(X_test_all[:,-1]),np.mean(X_test_all[:,0]))
    Y_test_pred_std = cycle_model.predict(X_scaled)
    Y_test_pred = scaler_Y_full_cycle.inverse_transform(Y_test_pred_std.reshape(-1,1))
    
    Y_test_pred = np.exp(Y_test_pred)
    Y_test_all = np.exp(Y_test_all)
    test_rmse = np.sqrt(mean_squared_error(Y_test_all, Y_test_pred))
    test_mae = mean_absolute_error(Y_test_all, Y_test_pred)
    test_pmae = np.mean(abs(Y_test_all-Y_test_pred)/Y_test_all)
    test_r2 = r2_score(Y_test_all.reshape(-1), Y_test_pred.reshape(-1))
    
    print(f"Test RMSE: {test_rmse:.4f}", f"Test MAE: {test_mae:.4f}", f"Test MPAE: {test_pmae:.4f}", f"Test R2: {test_r2:.4f}")
    
    Y_train_true = np.exp(Y_train_true)
    Y_train_pred = np.exp(Y_train_pred)
    
    Y_val_true = np.exp(Y_val_true)
    Y_val_pred = np.exp(Y_val_pred)
    
    train_rmse = np.sqrt(mean_squared_error(Y_train_true, Y_train_pred))
    train_mae = mean_absolute_error(Y_train_true, Y_train_pred)
    train_pmae = np.mean(abs(Y_train_true-Y_train_pred)/Y_train_true)
    train_r2 = r2_score(Y_train_true.reshape(-1), Y_train_pred.reshape(-1))
    
    print(f"Train RMSE: {train_rmse:.4f}", f"Train MAE: {train_mae:.4f}", f"Train MPAE: {train_pmae:.4f}", f"Train R2: {train_r2:.4f}")
    
    vali_rmse = np.sqrt(mean_squared_error(Y_val_true, Y_val_pred))
    vali_mae = mean_absolute_error(Y_val_true, Y_val_pred)
    vali_pmae = np.mean(abs(Y_val_true-Y_val_pred)/Y_val_true)
    vali_r2 = r2_score(Y_val_true.reshape(-1), Y_val_pred.reshape(-1))
    
    print(f"Vali RMSE: {vali_rmse:.4f}", f"Vali MAE: {vali_mae:.4f}", f"Vali MPAE: {vali_pmae:.4f}", f"Vali R2: {vali_r2:.4f}")
    
    
    Y_all_true = np.concatenate([Y_train_true, Y_val_true, Y_test_all])
    Y_all_pred = np.concatenate([Y_train_pred, Y_val_pred, Y_test_pred])
    print( r2_score(Y_all_true.reshape(-1), Y_all_pred.reshape(-1)))
    
    plt.figure(figsize=(6.5 / 2.54, 6 / 2.54), dpi=600)
    plt.ion()
    plt.rcParams['xtick.direction'] = 'in'
    plt.rcParams['ytick.direction'] = 'in'
    plt.tick_params(top='on', right='on', which='both')
    plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
    plt.scatter(Y_train_true,Y_train_pred,color=colors[0],marker='s',linewidths=0.8,edgecolors='black', alpha=0.8, label='Train')
    plt.scatter(Y_val_true,Y_val_pred,color=colors[5],marker='^',linewidths=0.8, edgecolors='black',alpha=0.8, label='Validation 1')
    plt.scatter(Y_test_all,Y_test_pred,color=colors[1],marker='o',linewidths=0.8, edgecolors='black', alpha=0.8, label='Validation 2')
    # plt.plot(Y_all_true,Y_all_true,color='grey')
    plt.plot(np.arange(-20,1.1*max(Y_all_true)),np.arange(-20,1.1*max(Y_all_true)),'--',color='grey')
    plt.legend(loc='upper left',
        handletextpad=0.1, 
        labelspacing=0.05,
        frameon=False,
        bbox_to_anchor=(-0.05, 1.05),
        fontsize=10)
    plt.xlabel('Actual EFC')
    plt.xlim([-20,1.1*max(Y_all_true)])
    plt.ylim([-20,1.1*max(Y_all_true)])
              
    plt.ylabel('Predicted EFC')
    plt.show()
    
    coefs = np.abs(cycle_model.coef_)
    coefs_all.append(coefs)
    
    #%%
#
 # cq,
 # cq-Cq_5[0],
 # cp,
 # cp-Cp_5[0],
 # cn,
 # cn-Cn_5[0],
 # cli,
 # cli-Cli_5[0],
 # np.var(voc_real-voc_fit),
 # np.var(voc_real-real_OCV[0,:]),
 # np.var(voc_fit-real_OCV[0,:]),
 # np.max(abs(voc_real-voc_fit)),
 
 # np.var(dv_dq_real[200:-200]-dv_dq_fit[200:-200]),
 # np.var(dv_dq_real0[200:-200]-dv_dq_real[200:-200]),
 # np.max(dv_dq_real[200:-200]-dv_dq_fit[200:-200]),
 # np.min(dv_dq_real[200:-200]-dv_dq_fit[200:-200]),
 # # efcs[j],
 # peaks_real[0],
 # peaks_fit[0],
 # peaks_real_val,
 # peaks_fit_val,

 # abs(peaks_real_val0-peaks_real_val),
 # abs(peaks_real0[0]-peaks_real[0]),
 # abs(peaks_fit_val0-peaks_fit_val),
 # abs(peaks_fit0[0]-peaks_fit[0]),
 # efcs[j],
 
 
# Display labels (matplotlib mathtext) for the 25 engineered features, listed in
# the same column order as the ``X_i`` vector assembled with ``np.column_stack``
# in the feature-extraction loops of this script. The groups are concatenated
# into a single flat list.
feature_names = (
    # Fitted parameters (Cq, Cp, Cn, Qli) at the selected check-up, each followed
    # by its difference from the first entry (index 0) of that cell's series.
    [
        r'${\mathrm{C_q}}$', r'${\Delta \mathrm{C_q}}$',
        r'${\mathrm{C_p}}$', r'${\Delta \mathrm{C_p}}$',
        r'${\mathrm{C_n}}$', r'${\Delta \mathrm{C_n}}$',
        r'${\mathrm{Q_{li}}}$', r'${\Delta \mathrm{Q_{li}}}$',
    ]
    # Statistics of measured vs. fitted voltage curves.
    + [
        r'$\mathrm{var}(\mathrm{V_m} - \mathrm{V_f})$',
        r'$\mathrm{var}(\Delta \mathrm{V_m})$',
        r'$\mathrm{var}(\mathrm{V_f} - \mathrm{V_m0})$',
        r'$\mathrm{max}(\mathrm{V_m} - \mathrm{V_f})$',
    ]
    # Statistics of the smoothed dV/dQ curves.
    + [
        r'$\mathrm{var}(\mathrm{dV_m} - \mathrm{dV_f})$',
        r'$\mathrm{var}(\Delta \mathrm{dV_m})$',
        r'$\mathrm{max}(\mathrm{dV_m} - \mathrm{dV_f})$',
        r'$\mathrm{min}(\mathrm{dV_m} - \mathrm{dV_f})$',
    ]
    # First dV/dQ peak: index, value, and change relative to index 0.
    + [
        r'${\mathrm{Peak Loc_m}}$', r'${\mathrm{Peak Loc_f}}$',
        r'${\mathrm{Peak_m}}$', r'${\mathrm{Peak_f}}$',
        r'$\Delta \mathrm{Peak_m}$',
        r'$\Delta \mathrm{Peak Loc_m}$',
        r'$\Delta \mathrm{Peak_f}$',
        r'$\Delta \mathrm{Peak Loc_f}$',
    ]
    # Equivalent full cycles at the selected check-up.
    + [r'$\mathrm{Current~EFC}$']
)


# feature_names = [
#     "cq", "cq-Cq5", "cp", "cp-Cp5", "cn", "cn-Cn5", "cli", "cli-Cli5",
#     "var(v_real - v_fit)", "var(v_real - v0)", "var(v_fit - v0)", "max|v_real - v_fit|",
#     "var(dv_dq - fit)", "var(dv_dq0 - dq)", "max(dv_dq err)", "min(dv_dq err)",
#     "peak_real_idx", "peak_fit_idx", "peak_real_val", "peak_fit_val",
#     "Δpeak_val_real", "Δpeak_idx_real", "Δpeak_val_fit", "Δpeak_idx_fit", "efc"
# ]
# Custom six-stop colormap used for the feature-importance bars and the
# coefficient heatmap below. A three-colour variant used to be assigned here
# and immediately overwritten; only this definition was ever used.
# `colors` is a palette defined earlier in the script (presumably a list of
# colour strings; at least 7 entries are indexed here).
cmap = mcolors.LinearSegmentedColormap.from_list(
    'custom_cmap',
    ['#E59693', colors[4], colors[5], colors[2], colors[1], colors[6]],
)

# Keep the features whose |coefficient| exceeds 1e-2 and order them by ascending
# magnitude for the importance plots below.
# NOTE(review): `coefs` is whatever the preceding training loop left behind
# (np.abs(cycle_model.coef_) of its last iteration); confirm that is the
# threshold you intend to plot.
nonzero_mask = coefs > 1e-2
coefs_nonzero = coefs[nonzero_mask]
feature_names_nonzero = np.array(feature_names)[nonzero_mask]

sorted_indices = np.argsort(coefs_nonzero)  # ascending (smallest first)
sorted_coefs = coefs_nonzero[sorted_indices]
sorted_names = feature_names_nonzero[sorted_indices]

n_bars = len(sorted_coefs)
if n_bars == 0:
    # Without this guard the bar width below fails with a bare ZeroDivisionError.
    raise ValueError("No feature has |coefficient| > 1e-2; nothing to plot.")

# Polar layout: one equally wide wedge per surviving feature.
theta = np.linspace(0.0, 2 * np.pi, n_bars, endpoint=False)
radii = sorted_coefs
width = 2 * np.pi / n_bars
# Spread bar colours evenly over the colormap. max(..., 1) avoids dividing by
# zero when only a single feature survives the threshold.
colors_bar = [cmap(k / max(n_bars - 1, 1)) for k in range(n_bars)]


# Polar bar chart of the surviving |coefficients|: one wedge per feature, in
# ascending order of magnitude, each annotated with its value just beyond the bar.
fig, ax = plt.subplots(subplot_kw=dict(polar=True), figsize=(6.5 / 2.54, 6 / 2.54), dpi=600)
ax.bar(
    x=theta,
    height=radii,
    width=width,
    color=colors_bar,
    edgecolor='black',
    linewidth=0.5,
    alpha=0.9,
)

# Bare polar frame: transparent background, no grid, no outer spine, and blank
# tick labels. The angular tick positions stay at the bar centres. Previously
# the x labels were set to the feature names and then cleared again, and the y
# labels were cleared twice; only the final state is applied here.
ax.set_facecolor("none")
ax.grid(False)
ax.spines['polar'].set_visible(False)
ax.set_xticks(theta)
ax.set_xticklabels([])
ax.set_yticklabels([])

# Annotate each bar with its coefficient. The label is pushed radially outward
# by 12 % of the largest bar and rotated to follow the bar's angle.
label_offset = max(radii) * 0.12
for angle, radius in zip(theta, radii):
    ax.text(
        angle, radius + label_offset,
        f"{radius:.2f}",
        ha='center',
        va='center',
        fontsize=8,
        rotation=np.degrees(angle),
        rotation_mode='anchor',
    )

fig.tight_layout(pad=0)
plt.show()

# Horizontal bar chart of the same surviving |coefficients|, one labelled row
# per feature. The data are in ascending order and the y-axis is inverted, so
# the rows run from smallest (top) to largest (bottom).
fig, ax = plt.subplots(figsize=(11 / 2.54, 8 / 2.54), dpi=600)
y_pos = np.arange(len(radii))

ax.barh(y_pos, radii, height=0.65, color=colors_bar,
        edgecolor='black', linewidth=0.5, alpha=0.9)

# Open frame: drop the right and top spines.
for side in ('right', 'top'):
    ax.spines[side].set_visible(False)

ax.set_yticks(y_pos)
ax.set_yticklabels(sorted_names)
ax.invert_yaxis()

plt.tight_layout()
plt.show()


# Heatmap of |coefficients|: one row per SOH threshold, one column per feature.
# Rows are reversed so that they read EFC @ 80 % ... 95 % SOH from top to bottom.
# NOTE(review): assumes `coefs_all` holds one row per threshold, appended in the
# order 0.95, 0.9, 0.85, 0.825, 0.8 (the order used by the threshold sweep
# further below). Confirm against the loop that fills it.
heatmap_rows = np.asarray(coefs_all)[::-1]
heatmap_row_labels = ['EFC @ 80% SOH', 'EFC @ 82.5% SOH', 'EFC @ 85% SOH', 'EFC @ 90% SOH', 'EFC @ 95% SOH']
if heatmap_rows.shape != (len(heatmap_row_labels), len(feature_names)):
    # A mismatch would otherwise mislabel rows/columns (or fail deep inside
    # matplotlib with a tick/label count error).
    raise ValueError(
        f"coefs_all has shape {heatmap_rows.shape}; expected "
        f"({len(heatmap_row_labels)}, {len(feature_names)}) to match the axis labels."
    )

fig, ax = plt.subplots(figsize=(24 / 2.54, 4 / 2.54), dpi=600)
# Hand the labels to seaborn so every row and column gets its own tick. With the
# default xticklabels="auto", seaborn may thin the ticks, and relabelling them
# afterwards would then shift or mismatch the feature names.
sns.heatmap(
    heatmap_rows,
    cmap=cmap,
    alpha=0.9,
    cbar_kws={'label': 'Feature contribution'},
    xticklabels=feature_names,
    yticklabels=heatmap_row_labels,
    ax=ax,
)
ax.tick_params(axis='x', labelrotation=90)
ax.tick_params(axis='y', labelrotation=0)
plt.show()


#%%
# Empty accumulator created just before the SOH-threshold sweep below.
# NOTE(review): it is not referenced within this part of the file. Presumably
# it collects per-threshold mean |SHAP| values later on; confirm.
mean_abs_shap_all = list()
for efc_treshold in  [0.95,0.9,0.85,0.825, 0.8]:
    print('Results for SOH @:',efc_treshold)
    all_cycle_life = []
   
    for idx in range(len(all_efcs)):
        efc_each_cell = all_efcs[idx]
        cap_each_cell = all_Cq[idx]
        # cap_each_cell = cap_each_cell/cap_each_cell[0]
        # plt.plot(efc_each_cell,cap_each_cell)
        interp_efc_q = interp1d(efc_each_cell, cap_each_cell, kind='linear',bounds_error=False, fill_value="extrapolate")
        cap_interp = interp_efc_q(np.arange(efc_each_cell[0],efc_each_cell[-1]+1))
        for i in range(len(cap_interp)):
            if cap_interp[i]<efc_treshold:
                break
        all_cycle_life.append(i)
                                  
    
    X_features_cycle = []
    Y_targets_cycle = []
    for i_idx, i in enumerate(train_batteries_indices):  #227, 38, 26, 100, 235 [7,216]
        efcs = all_efcs[i]
        # print('load data for cell:',all_cells[i][0])
        # C/5
        Cq_5 = np.array(all_Cq[i]) 
        Cp_5 = np.array(all_Cp_opt[i]) 
        Cn_5 = np.array(all_Cn_opt[i])
        Cli_5 = np.array(all_Cli[i])
        cell_ocv = all_cell_ocv[i]
        cell_Vreal = cell_ocv[:,:,0]*4.2
        cell_Qreal = cell_ocv[:,:,1]
        cell_vmea = all_cell_vmea[i]
        cell_Vmea = cell_vmea[:,:,0]*4.2
        cell_Qmea = cell_vmea[:,:,1]
        C40_pred = all_predictions_test[i]/norminal_c
        
        fit_OCV = np.array(all_OCV_fit[i])
        
        real_OCV = cell_Vmea
        
        if len(Cq_5)<3 : # 3--1   7--5
            continue
        for j in [1]:#range(len(Cq_5)):
            
            # 
            voc_fit = fit_OCV[j,:]
            voc_real = real_OCV[j,:]
            q_meas = cell_Qmea[j,:]
            v_meas = cell_Vmea[j,:]
            
            cp = Cp_5[j]
            cn = Cn_5[j]
            cq = Cq_5[j]
            cli = Cli_5[j]  # 
            
            cq_40 = C40_pred[j, 0]
            cp_40 = C40_pred[j, 1] 
            cn_40 = C40_pred[j, 2] 
            cli_40 = C40_pred[j, 3]
            
            
            # 
            dv_dq_real = gradient( voc_real, q_meas*norminal_c )
            dv_dq_real = savgol_filter(dv_dq_real, window_length=91, polyorder=1)
            dv_dq_fit = gradient( voc_fit, q_meas*norminal_c )
            dv_dq_fit = savgol_filter(dv_dq_fit, window_length=91, polyorder=1)
            dv_dq_real0 = gradient( real_OCV[0,:], cell_Qmea[0,:]*norminal_c )
            dv_dq_real0 = savgol_filter(dv_dq_real0, window_length=91, polyorder=1)
            
            dv_dq_fit0 = gradient( fit_OCV[0,:], cell_Qmea[0,:]*norminal_c )
            dv_dq_fit0 = savgol_filter(dv_dq_fit0, window_length=91, polyorder=1)
            
            
            region_mask = (q_meas*norminal_c) < 3
            Q_focus = Q[region_mask]
            real_focus = -dv_dq_real[region_mask]
            fit_focus = -dv_dq_fit[region_mask]
            real_focus0 = -dv_dq_real0[region_mask]
            fit_focus0 = -dv_dq_fit0[region_mask]
            
            peaks_real, _ = find_peaks(real_focus, prominence=0.01)
            peaks_fit, _ = find_peaks(fit_focus, prominence=0.01)
            peaks_real0, _ = find_peaks(real_focus0, prominence=0.01)
            peaks_fit0, _ = find_peaks(fit_focus0, prominence=0.01)
            
            peaks_real_val= real_focus[peaks_real[0]]
            peaks_fit_val= fit_focus[peaks_fit[0]]
            peaks_real_val0= real_focus0[peaks_real0[0]]
            peaks_fit_val0= fit_focus0[peaks_fit0[0]]
            
            
            
            X_i = np.column_stack([
                
                cq,
                cq-Cq_5[0],
                cp,
                cp-Cp_5[0],
                cn,
                cn-Cn_5[0],
                cli,
                cli-Cli_5[0],
              
                
                np.var(voc_real-voc_fit),
                np.var(voc_real-real_OCV[0,:]),
                np.var(voc_fit-real_OCV[0,:]),
                np.max(abs(voc_real-voc_fit)),
                np.var(dv_dq_real[200:-200]-dv_dq_fit[200:-200]),
                np.var(dv_dq_real0[200:-200]-dv_dq_real[200:-200]),
                np.max(dv_dq_real[200:-200]-dv_dq_fit[200:-200]),
                np.min(dv_dq_real[200:-200]-dv_dq_fit[200:-200]),
                # efcs[j],
                peaks_real[0],
                peaks_fit[0],
                peaks_real_val,
                peaks_fit_val,
                abs(peaks_real_val0-peaks_real_val),
                abs(peaks_real0[0]-peaks_real[0]),
                abs(peaks_fit_val0-peaks_fit_val),
                abs(peaks_fit0[0]-peaks_fit[0]),
                efcs[j],
            ])
            
           #### flatten 
            X_features_cycle.append(X_i.flatten())  # shape: (6000,)
            Y_targets_cycle.append((np.log(all_cycle_life[i])))  # shape: (1000,)
            
    model_name = 'tesla_cycle_life_C5.pkl'
    X_all = np.stack(X_features_cycle)  # shape: (n_samples, 6000)
    Y_all = np.stack(Y_targets_cycle).reshape(-1,1)   # shape: (n_samples, 1000)
    
    # print(X_all.shape)
    print(np.mean(X_all[:,-1]),np.mean(np.exp(Y_all)),np.mean(X_all[:,0]))
    X_train, X_val, Y_train, Y_val = train_test_split(X_all, Y_all, test_size=0.2, random_state=42)
    
    np.random.seed(123)
    kf = KFold(n_splits=5, shuffle=True, random_state=42)
    alphas = np.logspace(-3, 1, 30)
    best_alphas = []
    for train_index, val_index in kf.split(X_all):
        X_train, X_val = X_all[train_index], X_all[val_index]
        Y_train, Y_val = Y_all[train_index], Y_all[val_index]
        
        # 
        scaler_X_cycle  = StandardScaler().fit(X_train)
        scaler_Y_cycle  = StandardScaler().fit(Y_train)
        X_train_scaled = scaler_X_cycle.transform(X_train)
        X_val_scaled = scaler_X_cycle.transform(X_val)
        Y_train_scaled = scaler_Y_cycle.transform(Y_train)
        Y_val_scaled = scaler_Y_cycle.transform(Y_val)
    
        best_alpha, best_score = None, float('inf')
        for alpha in alphas:
            model = Lasso(alpha=alpha)
            model.fit(X_train_scaled, Y_train_scaled)
            Y_val_pred = scaler_Y_cycle.inverse_transform(model.predict(X_val_scaled).reshape(-1,1))
            Y_val_true = scaler_Y_cycle.inverse_transform(Y_val_scaled)
            rmse = np.sqrt(((Y_val_pred - Y_val_true) ** 2).mean())
            if rmse < best_score:
                best_alpha = alpha
                best_score = rmse
                joblib.dump(model, 'saved_fittings/'+'validation_model_cycle_life.pkl')
        
        best_alphas.append(best_alpha)
        # print(f"Fold best alpha: {best_alpha:.4f}, RMSE: {best_score:.4f}")
    
    # 
    final_alpha = np.mean(best_alphas)
    # print(f"\nFinal averaged alpha from 5 folds: {final_alpha:.4f}")
    
    #% training and validation test
    X_train, X_val, Y_train, Y_val = train_test_split(X_all, Y_all, test_size=0.3, random_state=42)
    scaler_X_cycle  = StandardScaler().fit(X_train)
    scaler_Y_cycle  = StandardScaler().fit(Y_train)
    X_train_scaled = scaler_X_cycle.transform(X_train)
    X_val_scaled = scaler_X_cycle.transform(X_val)
    Y_train_scaled = scaler_Y_cycle.transform(Y_train)
    Y_val_scaled = scaler_Y_cycle.transform(Y_val)
    
    model= joblib.load('saved_fittings/'+'validation_model_cycle_life.pkl')
    Y_val_pred = scaler_Y_cycle.inverse_transform(model.predict(X_val_scaled).reshape(-1,1))
    Y_val_true = scaler_Y_cycle.inverse_transform(Y_val_scaled)
    rmse = np.sqrt(((Y_val_pred - Y_val_true) ** 2).mean())
    
    Y_train_pred = scaler_Y_cycle.inverse_transform(model.predict(X_train_scaled).reshape(-1,1))
    Y_train_true = scaler_Y_cycle.inverse_transform(Y_train_scaled)
    rmse = np.sqrt(((Y_val_pred - Y_val_true) ** 2).mean())
    
    # 
    scaler_X_full_cycle = StandardScaler().fit(X_all)
    scaler_Y_full_cycle  = StandardScaler().fit(Y_all)
    
    X_scaled = scaler_X_full_cycle.transform(X_all)
    Y_scaled = scaler_Y_full_cycle.transform(Y_all)
    
    
    cycle_model = XGBRegressor(
        n_estimators=100,
        max_depth=10,
        learning_rate=0.12,
        subsample=0.86,
        colsample_bytree=0.8,
        # reg_alpha=0.8,
        # reg_lambda=20,
        random_state=42
    )
    cycle_model.fit(X_scaled, Y_scaled.ravel())  # ravel() 
    Y_test_pred_std = cycle_model.predict(X_scaled)
    Y_test_pred = scaler_Y_full_cycle.inverse_transform(Y_test_pred_std.reshape(-1, 1))
    
    # Y_test_pred = np.exp(Y_test_pred)
    # Y_all = np.exp(Y_all)
    # test_rmse = np.sqrt(mean_squared_error(Y_all, Y_test_pred))
    # print(f"Test RMSE: {test_rmse:.4f}")
    
    #%
    X_features_cycle = []
    Y_targets_cycle = []
    for i_idx, i in enumerate(test_batteries_indices):  #227, 38, 26, 100, 235 [7,216]
        efcs = all_efcs[i]
        # print('load data for cell:',all_cells[i][0])
        # C/5
        Cq_5 = np.array(all_Cq[i]) 
        Cp_5 = np.array(all_Cp_opt[i]) 
        Cn_5 = np.array(all_Cn_opt[i])
        Cli_5 = np.array(all_Cli[i])
        cell_ocv = all_cell_ocv[i]
        cell_Vreal = cell_ocv[:,:,0]*4.2
        cell_Qreal = cell_ocv[:,:,1]
        cell_vmea = all_cell_vmea[i]
        cell_Vmea = cell_vmea[:,:,0]*4.2
        cell_Qmea = cell_vmea[:,:,1]
        C40_pred = all_predictions_test[i]/norminal_c
        
        fit_OCV = np.array(all_OCV_fit[i])
        
        real_OCV = cell_Vmea
        
        if len(Cq_5)<3 : # 3--1  7--5
            continue
        for j in [1]:#range(len(Cq_5)):
            
           
            voc_fit = fit_OCV[j,:]
            voc_real = real_OCV[j,:]
            q_meas = cell_Qmea[j,:]
            v_meas = cell_Vmea[j,:]
            
            cp = Cp_5[j]
            cn = Cn_5[j]
            cq = Cq_5[j]
            cli = Cli_5[j]  #
            
            cq_40 = C40_pred[j, 0]
            cp_40 = C40_pred[j, 1] 
            cn_40 = C40_pred[j, 2] 
            cli_40 = C40_pred[j, 3]
            
            
           
            dv_dq_real = gradient( voc_real, q_meas*norminal_c )
            dv_dq_real = savgol_filter(dv_dq_real, window_length=91, polyorder=1)
            dv_dq_fit = gradient( voc_fit, q_meas*norminal_c )
            dv_dq_fit = savgol_filter(dv_dq_fit, window_length=91, polyorder=1)
            dv_dq_real0 = gradient( real_OCV[0,:], cell_Qmea[0,:]*norminal_c )
            dv_dq_real0 = savgol_filter(dv_dq_real0, window_length=91, polyorder=1)
            
            dv_dq_fit0 = gradient( fit_OCV[0,:], cell_Qmea[0,:]*norminal_c )
            dv_dq_fit0 = savgol_filter(dv_dq_fit0, window_length=91, polyorder=1)
            
            
            region_mask = (q_meas*norminal_c) < 3
            Q_focus = Q[region_mask]
            real_focus = -dv_dq_real[region_mask]
            fit_focus = -dv_dq_fit[region_mask]
            real_focus0 = -dv_dq_real0[region_mask]
            fit_focus0 = -dv_dq_fit0[region_mask]
            
            peaks_real, _ = find_peaks(real_focus, prominence=0.01)
            peaks_fit, _ = find_peaks(fit_focus, prominence=0.01)
            peaks_real0, _ = find_peaks(real_focus0, prominence=0.01)
            peaks_fit0, _ = find_peaks(fit_focus0, prominence=0.01)
            
            peaks_real_val= real_focus[peaks_real[0]]
            peaks_fit_val= fit_focus[peaks_fit[0]]
            peaks_real_val0= real_focus0[peaks_real0[0]]
            peaks_fit_val0= fit_focus0[peaks_fit0[0]]
            
            
            
            X_i = np.column_stack([
                
                cq,
                cq-Cq_5[0],
                cp,
                cp-Cp_5[0],
                cn,
                cn-Cn_5[0],
                cli,
                cli-Cli_5[0],
                
                # cq_40,
                # cq_40-C40_pred[0, 0],
                # cp_40,
                # cq_40-C40_pred[0, 1],
                # cn_40,
                # cq_40-C40_pred[0, 2],
                # cli_40,
                # cq_40-C40_pred[0, 3],
                
                np.var(voc_real-voc_fit),
                np.var(voc_real-real_OCV[0,:]),
                np.var(voc_fit-real_OCV[0,:]),
                np.max(abs(voc_real-voc_fit)),
                np.var(dv_dq_real[200:-200]-dv_dq_fit[200:-200]),
                np.var(dv_dq_real0[200:-200]-dv_dq_real[200:-200]),
                np.max(dv_dq_real[200:-200]-dv_dq_fit[200:-200]),
                np.min(dv_dq_real[200:-200]-dv_dq_fit[200:-200]),
                # efcs[j],
                peaks_real[0],
                peaks_fit[0],
                peaks_real_val,
                peaks_fit_val,
                abs(peaks_real_val0-peaks_real_val),
                abs(peaks_real0[0]-peaks_real[0]),
                abs(peaks_fit_val0-peaks_fit_val),
                abs(peaks_fit0[0]-peaks_fit[0]),
                efcs[j],
            ])
            
            
           #### flatten 
            X_features_cycle.append(X_i.flatten())  # shape: (6000,)
            Y_targets_cycle.append((np.log(all_cycle_life[i])))  # shape: (1000,)
            
            
    X_test_all = np.stack(X_features_cycle)  # shape: (n_samples, 6000)
    Y_test_all = np.stack(Y_targets_cycle).reshape(-1,1)   # shape: (n_samples, 1000)
    X_scaled = scaler_X_full_cycle.transform(X_test_all)
    Y_scaled = scaler_Y_full_cycle.transform(Y_test_all)
    # print(X_test_all.shape)
    print(np.mean(X_test_all[:,-1]),np.mean(X_test_all[:,0]))
    Y_test_pred_std = cycle_model.predict(X_scaled)
    Y_test_pred = scaler_Y_full_cycle.inverse_transform(Y_test_pred_std.reshape(-1,1))
    
    Y_test_pred = np.exp(Y_test_pred)
    Y_test_all = np.exp(Y_test_all)
    test_rmse = np.sqrt(mean_squared_error(Y_test_all, Y_test_pred))
    test_mae = mean_absolute_error(Y_test_all, Y_test_pred)
    test_pmae = np.mean(abs(Y_test_all-Y_test_pred)/Y_test_all)
    test_r2 = r2_score(Y_test_all.reshape(-1), Y_test_pred.reshape(-1))
    
    print(f"Test RMSE: {test_rmse:.4f}", f"Test MAE: {test_mae:.4f}", f"Test MPAE: {test_pmae:.4f}", f"Test R2: {test_r2:.4f}")
    
    Y_train_true = np.exp(Y_train_true)
    Y_train_pred = np.exp(Y_train_pred)
    
    Y_val_true = np.exp(Y_val_true)
    Y_val_pred = np.exp(Y_val_pred)
    
    train_rmse = np.sqrt(mean_squared_error(Y_train_true, Y_train_pred))
    train_mae = mean_absolute_error(Y_train_true, Y_train_pred)
    train_pmae = np.mean(abs(Y_train_true-Y_train_pred)/Y_train_true)
    train_r2 = r2_score(Y_train_true.reshape(-1), Y_train_pred.reshape(-1))
    
    print(f"Train RMSE: {train_rmse:.4f}", f"Train MAE: {train_mae:.4f}", f"Train MPAE: {train_pmae:.4f}", f"Train R2: {train_r2:.4f}")
    
    vali_rmse = np.sqrt(mean_squared_error(Y_val_true, Y_val_pred))
    vali_mae = mean_absolute_error(Y_val_true, Y_val_pred)
    vali_pmae = np.mean(abs(Y_val_true-Y_val_pred)/Y_val_true)
    vali_r2 = r2_score(Y_val_true.reshape(-1), Y_val_pred.reshape(-1))
    
    print(f"Vali RMSE: {vali_rmse:.4f}", f"Vali MAE: {vali_mae:.4f}", f"Vali MPAE: {vali_pmae:.4f}", f"Vali R2: {vali_r2:.4f}")
    
    
    Y_all_true = np.concatenate([Y_train_true, Y_val_true, Y_test_all])
    Y_all_pred = np.concatenate([Y_train_pred, Y_val_pred, Y_test_pred])
    print( r2_score(Y_all_true.reshape(-1), Y_all_pred.reshape(-1)))
    
    plt.figure(figsize=(6.5 / 2.54, 6 / 2.54), dpi=600)
    plt.ion()
    plt.rcParams['xtick.direction'] = 'in'
    plt.rcParams['ytick.direction'] = 'in'
    plt.tick_params(top='on', right='on', which='both')
    plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
    plt.scatter(Y_train_true,Y_train_pred,color=colors[0],marker='s',linewidths=0.8,edgecolors='black', alpha=0.8, label='Train')
    plt.scatter(Y_val_true,Y_val_pred,color=colors[5],marker='^',linewidths=0.8, edgecolors='black',alpha=0.8, label='Validation 1')
    plt.scatter(Y_test_all,Y_test_pred,color=colors[1],marker='o',linewidths=0.8, edgecolors='black', alpha=0.8, label='Validation 2')
    # plt.plot(Y_all_true,Y_all_true,color='grey')
    plt.plot(np.arange(-20,1.1*max(Y_all_true)),np.arange(-20,1.1*max(Y_all_true)),'--',color='grey')
    plt.legend(loc='upper left',
        handletextpad=0.1, 
        labelspacing=0.05,
        frameon=False,
        bbox_to_anchor=(-0.05, 1.05),
        fontsize=10)
    plt.xlabel('Actual EFC')
    plt.xlim([-20,1.1*max(Y_all_true)])
    plt.ylim([-20,1.1*max(Y_all_true)])
              
    plt.ylabel('Predicted EFC')
    plt.show()
    
    explainer = shap.Explainer(cycle_model, X_scaled)
    shap_values = explainer(X_scaled)
    

    # shap.plots.bar(shap_values, max_display=20)  # 
    
    shap.plots.beeswarm(shap_values, max_display=20)
    
    mean_abs_shap = np.abs(shap_values.values).mean(axis=0)
    mean_abs_shap_all.append(mean_abs_shap)
    # # 
    # sorted_idx = np.argsort(mean_abs_shap)
    # sorted_shap = mean_abs_shap[sorted_idx]
    # sorted_names = np.array(feature_names)[sorted_idx]
    
    # # 
    # fig, ax = plt.subplots(figsize=(11/2.54, 8/2.54), dpi=600)
    # ax.barh(range(len(sorted_shap)), sorted_shap, color='#0073B1', edgecolor='black')
    # ax.set_yticks(range(len(sorted_names)))
    # ax.set_yticklabels(sorted_names)
    # ax.invert_yaxis()
    # plt.tight_layout()
    # plt.show()
    
#%%
top_k = len(coefs_nonzero)
# Feature indices ordered by mean |SHAP|, largest first.
sorted_idx = np.flip(np.argsort(mean_abs_shap))
# Keep only the top_k entries, then flip them so they run smallest -> largest.
sorted_shap = np.flip(mean_abs_shap[sorted_idx[:top_k]])
sorted_names = np.flip(np.array(feature_names)[sorted_idx[:top_k]])

# Polar bar chart of sorted_coefs: one equal-width wedge per value, coloured
# along cmap and annotated with its value (two decimals). Angular and radial
# tick labels, the grid, the background and the outer spine are all hidden.
theta = np.linspace(0.0, 2 * np.pi, len(sorted_coefs), endpoint=False)
radii = sorted_coefs
width = 2 * np.pi / len(radii)
# Spread the colours over [0, 1]; max(..., 1) avoids a ZeroDivisionError when
# there is only a single coefficient.
n_bars = len(radii)
colors_bar = [cmap(i / max(n_bars - 1, 1)) for i in range(n_bars)]

fig, ax = plt.subplots(subplot_kw=dict(polar=True), figsize=(6.5 / 2.54, 6 / 2.54), dpi=600)
ax.bar(
    x=theta,
    height=radii,
    width=width,
    color=colors_bar,
    edgecolor='black',
    linewidth=0.5,
    alpha=0.9
)

# Ticks stay at the wedge centres but carry no text. They are deliberately not
# labelled with sorted_names: that array is ordered by mean |SHAP|, which need
# not match the order of sorted_coefs.
ax.set_xticks(theta)
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_facecolor("none")
ax.grid(False)
ax.spines['polar'].set_visible(False)

# Value labels sit slightly outside each wedge and are rotated to its angle.
label_offset = max(radii) * 0.12
for angle, radius in zip(theta, radii):
    ax.text(
        angle,
        radius + label_offset,
        f"{radius:.2f}",
        ha='center',
        va='center',
        fontsize=8,
        rotation=np.degrees(angle),
        rotation_mode='anchor'
    )

fig.tight_layout(pad=0)
plt.show()

# Polar bar chart of the top_k mean |SHAP| values (sorted_shap, ascending):
# equal-width wedges coloured along cmap and annotated with their values. Tick
# labels, the grid, the background and the outer spine are hidden.
# radii and colors_bar are reused by the horizontal bar chart that follows.
theta = np.linspace(0.0, 2 * np.pi, top_k, endpoint=False)
radii = sorted_shap
width = 2 * np.pi / top_k
# max(..., 1) avoids a ZeroDivisionError when top_k == 1.
colors_bar = [cmap(i / max(top_k - 1, 1)) for i in range(top_k)]

fig, ax = plt.subplots(subplot_kw=dict(polar=True), figsize=(6.5 / 2.54, 6 / 2.54), dpi=600)
ax.bar(
    x=theta,
    height=radii,
    width=width,
    color=colors_bar,
    edgecolor='black',
    linewidth=0.5,
    alpha=0.9
)

# Keep ticks at the wedge centres but without any text.
ax.set_xticks(theta)
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_facecolor("none")
ax.grid(False)
ax.spines['polar'].set_visible(False)

# Value labels sit slightly outside each wedge; the offset is loop-invariant.
label_offset = max(radii) * 0.12
for angle, radius in zip(theta, radii):
    ax.text(
        angle,
        radius + label_offset,
        f"{radius:.2f}",
        ha='center',
        va='center',
        fontsize=8,
        rotation=np.degrees(angle),
        rotation_mode='anchor'
    )

fig.tight_layout(pad=0)
plt.show()

# Horizontal bar chart of the values in radii, using the colours in colors_bar
# (both as set by the preceding polar chart). There is one row per feature,
# labelled with sorted_names. The y axis is inverted so that row 0 is drawn at
# the top, and the top and right spines are hidden.
fig, ax = plt.subplots(figsize=(11 / 2.54, 8 / 2.54), dpi=600)
y_pos = np.arange(top_k)

ax.barh(y_pos, radii, 0.65,
        color=colors_bar, edgecolor='black', linewidth=0.5, alpha=0.9)
ax.set_yticks(y_pos)
ax.set_yticklabels(sorted_names)
for _side in ('right', 'top'):
    ax.spines[_side].set_visible(False)
ax.invert_yaxis()
plt.tight_layout()
plt.show()


# Heat map of mean |SHAP| per feature (columns) for each prediction target
# (rows). The rows of mean_abs_shap_all are reversed before plotting.
# NOTE(review): the row labels assume mean_abs_shap_all holds exactly five
# rows appended in the order 95%, 90%, 85%, 82.5%, 80% SOH — confirm against
# the loop that fills it.
shap_rows = np.array(mean_abs_shap_all)[::-1]
row_labels = ['EFC @ 80% SOH', 'EFC @ 82.5% SOH', 'EFC @ 85% SOH',
              'EFC @ 90% SOH', 'EFC @ 95% SOH']

fig, ax = plt.subplots(figsize=(24 / 2.54, 4 / 2.54), dpi=600)
# Hand the labels to seaborn directly. With the default xticklabels="auto",
# seaborn may draw only every k-th tick when there are many columns. A later
# set_xticklabels call would then attach names to the wrong cells, or fail on
# the count mismatch.
sns.heatmap(shap_rows, cmap=cmap, alpha=0.9, ax=ax,
            xticklabels=feature_names, yticklabels=row_labels,
            cbar_kws={'label': 'Feature contribution'})
plt.setp(ax.get_xticklabels(), rotation=90)
plt.setp(ax.get_yticklabels(), rotation=0)
plt.show()